forked from TrueCloudLab/certificates
Merge pull request #161 from smallstep/unittests
Introduce generalized statusCoder errors and loads of ssh unit tests.
This commit is contained in:
commit
f3f8ee4207
88 changed files with 5620 additions and 2544 deletions
|
@ -63,6 +63,7 @@ issues:
|
|||
- declaration of "err" shadows declaration at line
|
||||
- should have a package comment, unless it's in another file for this package
|
||||
- error strings should not be capitalized or end with punctuation or a newline
|
||||
- Wrapf call needs 1 arg but has 2 args
|
||||
# golangci.com configuration
|
||||
# https://github.com/golangci/golangci/wiki/Configuration
|
||||
service:
|
||||
|
|
120
api/api.go
120
api/api.go
|
@ -5,7 +5,6 @@ import (
|
|||
"crypto/dsa"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
|
@ -209,14 +208,6 @@ type RootResponse struct {
|
|||
RootPEM Certificate `json:"ca"`
|
||||
}
|
||||
|
||||
// SignRequest is the request body for a certificate signature request.
|
||||
type SignRequest struct {
|
||||
CsrPEM CertificateRequest `json:"csr"`
|
||||
OTT string `json:"ott"`
|
||||
NotAfter TimeDuration `json:"notAfter"`
|
||||
NotBefore TimeDuration `json:"notBefore"`
|
||||
}
|
||||
|
||||
// ProvisionersResponse is the response object that returns the list of
|
||||
// provisioners.
|
||||
type ProvisionersResponse struct {
|
||||
|
@ -230,31 +221,6 @@ type ProvisionerKeyResponse struct {
|
|||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
// Validate checks the fields of the SignRequest and returns nil if they are ok
|
||||
// or an error if something is wrong.
|
||||
func (s *SignRequest) Validate() error {
|
||||
if s.CsrPEM.CertificateRequest == nil {
|
||||
return errs.BadRequest(errors.New("missing csr"))
|
||||
}
|
||||
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
|
||||
return errs.BadRequest(errors.Wrap(err, "invalid csr"))
|
||||
}
|
||||
if s.OTT == "" {
|
||||
return errs.BadRequest(errors.New("missing ott"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignResponse is the response object of the certificate signature request.
|
||||
type SignResponse struct {
|
||||
ServerPEM Certificate `json:"crt"`
|
||||
CaPEM Certificate `json:"ca"`
|
||||
CertChainPEM []Certificate `json:"certChain"`
|
||||
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
|
||||
TLS *tls.ConnectionState `json:"-"`
|
||||
}
|
||||
|
||||
// RootsResponse is the response object of the roots request.
|
||||
type RootsResponse struct {
|
||||
Certificates []Certificate `json:"crts"`
|
||||
|
@ -329,7 +295,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
|||
// Load root certificate with the
|
||||
cert, err := h.Authority.Root(sum)
|
||||
if err != nil {
|
||||
WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI)))
|
||||
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -344,91 +310,17 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
|||
return certChainPEM
|
||||
}
|
||||
|
||||
// Sign is an HTTP handler that reads a certificate request and an
|
||||
// one-time-token (ott) from the body and creates a new certificate with the
|
||||
// information in the certificate request.
|
||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
opts := provisioner.Options{
|
||||
NotBefore: body.NotBefore,
|
||||
NotAfter: body.NotAfter,
|
||||
}
|
||||
|
||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 0 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
logCertificate(w, certChain[0])
|
||||
JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
WriteError(w, errs.BadRequest(errors.New("missing peer certificate")))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 0 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
|
||||
logCertificate(w, certChain[0])
|
||||
JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Provisioners returns the list of provisioners configured in the authority.
|
||||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := parseCursor(r)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &ProvisionersResponse{
|
||||
|
@ -442,7 +334,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
|||
kid := chi.URLParam(r, "kid")
|
||||
key, err := h.Authority.GetEncryptedKey(kid)
|
||||
if err != nil {
|
||||
WriteError(w, errs.NotFound(err))
|
||||
WriteError(w, errs.NotFoundErr(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &ProvisionerKeyResponse{key})
|
||||
|
@ -452,7 +344,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -470,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := h.Authority.GetFederation()
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/sshutil"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
|
@ -914,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
|
||||
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
|
@ -934,13 +935,13 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.Root unexpected error = %v", err)
|
||||
t.Errorf("caHandler.Renew unexpected error = %v", err)
|
||||
}
|
||||
if tt.statusCode < http.StatusBadRequest {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||||
|
@ -1009,8 +1010,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedError400 := []byte(`{"status":400,"message":"Bad Request"}`)
|
||||
expectedError500 := []byte(`{"status":500,"message":"Internal Server Error"}`)
|
||||
expectedError400 := errs.BadRequest("force")
|
||||
expectedError400Bytes, err := json.Marshal(expectedError400)
|
||||
assert.FatalError(t, err)
|
||||
expectedError500 := errs.InternalServer("force")
|
||||
expectedError500Bytes, err := json.Marshal(expectedError500)
|
||||
assert.FatalError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &caHandler{
|
||||
|
@ -1035,12 +1040,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
} else {
|
||||
switch tt.statusCode {
|
||||
case 400:
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expectedError400) {
|
||||
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400)
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expectedError400Bytes) {
|
||||
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400Bytes)
|
||||
}
|
||||
case 500:
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expectedError500) {
|
||||
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500)
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expectedError500Bytes) {
|
||||
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500Bytes)
|
||||
}
|
||||
default:
|
||||
t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode)
|
||||
|
@ -1077,7 +1082,9 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
}
|
||||
|
||||
expected := []byte(`{"key":"` + privKey + `"}`)
|
||||
expectedError := []byte(`{"status":404,"message":"Not Found"}`)
|
||||
expectedError404 := errs.NotFound("force")
|
||||
expectedError404Bytes, err := json.Marshal(expectedError404)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -1101,8 +1108,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected)
|
||||
}
|
||||
} else {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expectedError) {
|
||||
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError)
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expectedError404Bytes) {
|
||||
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError404Bytes)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
35
api/renew.go
Normal file
35
api/renew.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
WriteError(w, errs.BadRequest("missing peer certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
|
||||
if err != nil {
|
||||
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 1 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
|
||||
logCertificate(w, certChain[0])
|
||||
JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
@ -30,13 +29,13 @@ type RevokeRequest struct {
|
|||
// or an error if something is wrong.
|
||||
func (r *RevokeRequest) Validate() (err error) {
|
||||
if r.Serial == "" {
|
||||
return errs.BadRequest(errors.New("missing serial"))
|
||||
return errs.BadRequest("missing serial")
|
||||
}
|
||||
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
|
||||
return errs.BadRequest(errors.New("reasonCode out of bounds"))
|
||||
return errs.BadRequest("reasonCode out of bounds")
|
||||
}
|
||||
if !r.Passive {
|
||||
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
|
||||
return errs.NotImplemented("non-passive revocation not implemented")
|
||||
}
|
||||
|
||||
return
|
||||
|
@ -50,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
|
|||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -72,7 +71,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
if len(body.OTT) > 0 {
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
@ -81,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
// the client certificate Serial Number must match the serial number
|
||||
// being revoked.
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate")))
|
||||
WriteError(w, errs.BadRequest("missing ott or peer certificate"))
|
||||
return
|
||||
}
|
||||
opts.Crt = r.TLS.PeerCertificates[0]
|
||||
if opts.Crt.SerialNumber.String() != opts.Serial {
|
||||
WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body")))
|
||||
WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body"))
|
||||
return
|
||||
}
|
||||
// TODO: should probably be checking if the certificate was revoked here.
|
||||
|
@ -97,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -190,7 +190,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
return errs.InternalServerError(errors.New("force"))
|
||||
return errs.InternalServer("force")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
89
api/sign.go
Normal file
89
api/sign.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
)
|
||||
|
||||
// SignRequest is the request body for a certificate signature request.
|
||||
type SignRequest struct {
|
||||
CsrPEM CertificateRequest `json:"csr"`
|
||||
OTT string `json:"ott"`
|
||||
NotAfter TimeDuration `json:"notAfter"`
|
||||
NotBefore TimeDuration `json:"notBefore"`
|
||||
}
|
||||
|
||||
// Validate checks the fields of the SignRequest and returns nil if they are ok
|
||||
// or an error if something is wrong.
|
||||
func (s *SignRequest) Validate() error {
|
||||
if s.CsrPEM.CertificateRequest == nil {
|
||||
return errs.BadRequest("missing csr")
|
||||
}
|
||||
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
|
||||
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
|
||||
}
|
||||
if s.OTT == "" {
|
||||
return errs.BadRequest("missing ott")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignResponse is the response object of the certificate signature request.
|
||||
type SignResponse struct {
|
||||
ServerPEM Certificate `json:"crt"`
|
||||
CaPEM Certificate `json:"ca"`
|
||||
CertChainPEM []Certificate `json:"certChain"`
|
||||
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
|
||||
TLS *tls.ConnectionState `json:"-"`
|
||||
}
|
||||
|
||||
// Sign is an HTTP handler that reads a certificate request and an
|
||||
// one-time-token (ott) from the body and creates a new certificate with the
|
||||
// information in the certificate request.
|
||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
opts := provisioner.Options{
|
||||
NotBefore: body.NotBefore,
|
||||
NotAfter: body.NotAfter,
|
||||
}
|
||||
|
||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 1 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
logCertificate(w, certChain[0])
|
||||
JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
48
api/ssh.go
48
api/ssh.go
|
@ -249,19 +249,19 @@ type SSHBastionResponse struct {
|
|||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -269,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
if body.AddUserPublicKey != nil {
|
||||
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -282,16 +282,16 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
ValidAfter: body.ValidAfter,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -299,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
|
||||
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
addUserCertificate = &SSHCertificate{addUserCert}
|
||||
|
@ -320,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
certChain, err := h.Authority.Sign(cr, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
identityCertificate = certChainToPEM(certChain)
|
||||
|
@ -343,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHRoots()
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
WriteError(w, errs.NotFound(errors.New("no keys found")))
|
||||
WriteError(w, errs.NotFound("no keys found"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -368,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHFederation()
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
WriteError(w, errs.NotFound(errors.New("no keys found")))
|
||||
WriteError(w, errs.NotFound("no keys found"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -393,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -414,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
case provisioner.SSHHostCert:
|
||||
config.HostTemplates = ts
|
||||
default:
|
||||
WriteError(w, errs.InternalServerError(errors.New("it should hot get here")))
|
||||
WriteError(w, errs.InternalServer("it should hot get here"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -429,13 +429,13 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &SSHCheckPrincipalResponse{
|
||||
|
@ -452,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
hosts, err := h.Authority.GetSSHHosts(cert)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &SSHGetHostsResponse{
|
||||
|
@ -464,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -40,42 +40,42 @@ type SSHRekeyResponse struct {
|
|||
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RekeySSHMethod)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
}
|
||||
|
||||
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -36,36 +36,36 @@ type SSHRenewResponse struct {
|
|||
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RenewSSHMethod)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
|
||||
_, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
}
|
||||
|
||||
newCert, err := h.Authority.RenewSSH(oldCert)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
@ -30,16 +29,16 @@ type SSHRevokeRequest struct {
|
|||
// or an error if something is wrong.
|
||||
func (r *SSHRevokeRequest) Validate() (err error) {
|
||||
if r.Serial == "" {
|
||||
return errs.BadRequest(errors.New("missing serial"))
|
||||
return errs.BadRequest("missing serial")
|
||||
}
|
||||
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
|
||||
return errs.BadRequest(errors.New("reasonCode out of bounds"))
|
||||
return errs.BadRequest("reasonCode out of bounds")
|
||||
}
|
||||
if !r.Passive {
|
||||
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
|
||||
return errs.NotImplemented("non-passive revocation not implemented")
|
||||
}
|
||||
if len(r.OTT) == 0 {
|
||||
return errs.BadRequest(errors.New("missing ott"))
|
||||
return errs.BadRequest("missing ott")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -50,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
|||
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -66,18 +65,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
|||
PassiveOnly: body.Passive,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeSSHMethod)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod)
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
@ -69,7 +68,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
|||
// pointed by v.
|
||||
func ReadJSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return errs.BadRequest(errors.Wrap(err, "error decoding json"))
|
||||
return errs.Wrap(http.StatusBadRequest, err, "error decoding json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -13,12 +13,13 @@ import (
|
|||
stepJOSE "github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func testAuthority(t *testing.T) *Authority {
|
||||
func testAuthority(t *testing.T, opts ...Option) *Authority {
|
||||
maxjwk, err := stepJOSE.ParseKey("testdata/secrets/max_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
disableRenewal := true
|
||||
enableSSHCA := true
|
||||
p := provisioner.List{
|
||||
&provisioner.JWK{
|
||||
Name: "Max",
|
||||
|
@ -29,6 +30,9 @@ func testAuthority(t *testing.T) *Authority {
|
|||
Name: "step-cli",
|
||||
Type: "JWK",
|
||||
Key: clijwk,
|
||||
Claims: &provisioner.Claims{
|
||||
EnableSSHCA: &enableSSHCA,
|
||||
},
|
||||
},
|
||||
&provisioner.JWK{
|
||||
Name: "dev",
|
||||
|
@ -46,19 +50,30 @@ func testAuthority(t *testing.T) *Authority {
|
|||
DisableRenewal: &disableRenewal,
|
||||
},
|
||||
},
|
||||
&provisioner.SSHPOP{
|
||||
Name: "sshpop",
|
||||
Type: "SSHPOP",
|
||||
Claims: &provisioner.Claims{
|
||||
EnableSSHCA: &enableSSHCA,
|
||||
},
|
||||
},
|
||||
}
|
||||
c := &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: []string{"testdata/certs/root_ca.crt"},
|
||||
IntermediateCert: "testdata/certs/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.ca.smallstep.com"},
|
||||
SSH: &SSHConfig{
|
||||
HostKey: "testdata/secrets/ssh_host_ca_key",
|
||||
UserKey: "testdata/secrets/ssh_user_ca_key",
|
||||
},
|
||||
DNSNames: []string{"example.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: &AuthConfig{
|
||||
Provisioners: p,
|
||||
},
|
||||
}
|
||||
a, err := New(c)
|
||||
a, err := New(c, opts...)
|
||||
assert.FatalError(t, err)
|
||||
return a
|
||||
}
|
||||
|
|
|
@ -6,9 +6,10 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Claims extends jose.Claims with step attributes.
|
||||
|
@ -36,22 +37,19 @@ func SkipTokenReuseFromContext(ctx context.Context) bool {
|
|||
// authorizeToken parses the token and returns the provisioner used to generate
|
||||
// the token. This method enforces the One-Time use policy (tokens can only be
|
||||
// used once).
|
||||
func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner.Interface, error) {
|
||||
var errContext = map[string]interface{}{"ott": ott}
|
||||
|
||||
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
|
||||
// Validate payload
|
||||
token, err := jose.ParseSigned(ott)
|
||||
tok, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrapf(err, "authorizeToken: error parsing token"),
|
||||
http.StatusUnauthorized, errContext}
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token")
|
||||
}
|
||||
|
||||
// Get claims w/out verification. We need to look up the provisioner
|
||||
// key in order to verify the claims and we need the issuer from the claims
|
||||
// before we can look up the provisioner.
|
||||
var claims Claims
|
||||
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorizeToken"), http.StatusUnauthorized, errContext}
|
||||
if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
|
||||
}
|
||||
|
||||
// TODO: use new persistence layer abstraction.
|
||||
|
@ -59,29 +57,27 @@ func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner
|
|||
// This check is meant as a stopgap solution to the current lack of a persistence layer.
|
||||
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
|
||||
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
|
||||
return nil, &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
|
||||
http.StatusUnauthorized, errContext}
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority")
|
||||
}
|
||||
}
|
||||
|
||||
// This method will also validate the audiences for JWK provisioners.
|
||||
p, ok := a.provisioners.LoadByToken(token, &claims.Claims)
|
||||
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
|
||||
if !ok {
|
||||
return nil, &apiError{
|
||||
errors.Errorf("authorizeToken: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")),
|
||||
http.StatusUnauthorized, errContext}
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
|
||||
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
|
||||
}
|
||||
|
||||
// Store the token to protect against reuse unless it's skipped.
|
||||
if !SkipTokenReuseFromContext(ctx) {
|
||||
if reuseKey, err := p.GetTokenID(ott); err == nil {
|
||||
ok, err := a.db.UseToken(reuseKey, ott)
|
||||
if reuseKey, err := p.GetTokenID(token); err == nil {
|
||||
ok, err := a.db.UseToken(reuseKey, token)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorizeToken: failed when attempting to store token"),
|
||||
http.StatusInternalServerError, errContext}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.authorizeToken: failed when attempting to store token")
|
||||
}
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("authorizeToken: token already used"), http.StatusUnauthorized, errContext}
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: token already used")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -89,125 +85,158 @@ func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner
|
|||
return p, nil
|
||||
}
|
||||
|
||||
// Authorize grabs the method from the context and authorizes a signature
|
||||
// request by validating the one-time-token.
|
||||
func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
var errContext = apiCtx{"ott": ott}
|
||||
// Authorize grabs the method from the context and authorizes the request by
|
||||
// validating the one-time-token.
|
||||
func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
var opts = []interface{}{errs.WithKeyVal("token", token)}
|
||||
|
||||
switch m := provisioner.MethodFromContext(ctx); m {
|
||||
case provisioner.SignMethod:
|
||||
return a.authorizeSign(ctx, ott)
|
||||
signOpts, err := a.authorizeSign(ctx, token)
|
||||
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
|
||||
case provisioner.RevokeMethod:
|
||||
return nil, a.authorizeRevoke(ctx, ott)
|
||||
case provisioner.SignSSHMethod:
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...)
|
||||
case provisioner.SSHSignMethod:
|
||||
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
|
||||
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
|
||||
}
|
||||
return a.authorizeSSHSign(ctx, ott)
|
||||
case provisioner.RenewSSHMethod:
|
||||
_, err := a.authorizeSSHSign(ctx, token)
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
|
||||
case provisioner.SSHRenewMethod:
|
||||
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
|
||||
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
|
||||
}
|
||||
if _, err := a.authorizeSSHRenew(ctx, ott); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
case provisioner.RevokeSSHMethod:
|
||||
return nil, a.authorizeSSHRevoke(ctx, ott)
|
||||
case provisioner.RekeySSHMethod:
|
||||
_, err := a.authorizeSSHRenew(ctx, token)
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
|
||||
case provisioner.SSHRevokeMethod:
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...)
|
||||
case provisioner.SSHRekeyMethod:
|
||||
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
|
||||
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
|
||||
}
|
||||
_, opts, err := a.authorizeSSHRekey(ctx, ott)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return opts, nil
|
||||
_, signOpts, err := a.authorizeSSHRekey(ctx, token)
|
||||
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
|
||||
default:
|
||||
return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext}
|
||||
return nil, errs.InternalServer("authority.Authorize; method %d is not supported", append([]interface{}{m}, opts...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeSign loads the provisioner from the token, checks that it has not
|
||||
// been used again and calls the provisioner AuthorizeSign method. Returns a
|
||||
// list of methods to apply to the signing flow.
|
||||
func (a *Authority) authorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
var errContext = apiCtx{"ott": ott}
|
||||
p, err := a.authorizeToken(ctx, ott)
|
||||
// authorizeSign loads the provisioner from the token and calls the provisioner
|
||||
// AuthorizeSign method. Returns a list of methods to apply to the signing flow.
|
||||
func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign")
|
||||
}
|
||||
opts, err := p.AuthorizeSign(ctx, ott)
|
||||
signOpts, err := p.AuthorizeSign(ctx, token)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign")
|
||||
}
|
||||
return opts, nil
|
||||
return signOpts, nil
|
||||
}
|
||||
|
||||
// AuthorizeSign authorizes a signature request by validating and authenticating
|
||||
// a OTT that must be sent w/ the request.
|
||||
// a token that must be sent w/ the request.
|
||||
//
|
||||
// NOTE: This method is deprecated and should not be used. We make it available
|
||||
// in the short term os as not to break existing clients.
|
||||
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
return a.Authorize(ctx, ott)
|
||||
return a.Authorize(ctx, token)
|
||||
}
|
||||
|
||||
// authorizeRevoke authorizes a revocation request by validating and authenticating
|
||||
// the RevokeOptions POSTed with the request.
|
||||
// Returns a tuple of the provisioner ID and error, if one occurred.
|
||||
// authorizeRevoke locates the provisioner used to generate the authenticating
|
||||
// token and then performs the token validation flow.
|
||||
func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
||||
errContext := map[string]interface{}{"ott": token}
|
||||
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return &apiError{errors.Wrap(err, "authorizeRevoke"), http.StatusUnauthorized, errContext}
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
||||
}
|
||||
if err = p.AuthorizeRevoke(ctx, token); err != nil {
|
||||
return &apiError{errors.Wrap(err, "authorizeRevoke"), http.StatusUnauthorized, errContext}
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorizeRenewl tries to locate the step provisioner extension, and checks
|
||||
// authorizeRenew locates the provisioner (using the provisioner extension in the cert), and checks
|
||||
// if for the configured provisioner, the renewal is enabled or not. If the
|
||||
// extra extension cannot be found, authorize the renewal by default.
|
||||
//
|
||||
// TODO(mariano): should we authorize by default?
|
||||
func (a *Authority) authorizeRenew(crt *x509.Certificate) error {
|
||||
errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()}
|
||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
|
||||
|
||||
// Check the passive revocation table.
|
||||
isRevoked, err := a.db.IsRevoked(crt.SerialNumber.String())
|
||||
isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String())
|
||||
if err != nil {
|
||||
return &apiError{
|
||||
err: errors.Wrap(err, "renew"),
|
||||
code: http.StatusInternalServerError,
|
||||
context: errContext,
|
||||
}
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
}
|
||||
if isRevoked {
|
||||
return &apiError{
|
||||
err: errors.New("renew: certificate has been revoked"),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
|
||||
}
|
||||
|
||||
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||
p, ok := a.provisioners.LoadByCertificate(cert)
|
||||
if !ok {
|
||||
return &apiError{
|
||||
err: errors.New("renew: provisioner not found"),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
if err := p.AuthorizeRenew(context.Background(), crt); err != nil {
|
||||
return &apiError{
|
||||
err: errors.Wrap(err, "renew"),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||
}
|
||||
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorizeSSHSign loads the provisioner from the token, checks that it has not
|
||||
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
|
||||
// list of methods to apply to the signing flow.
|
||||
func (a *Authority) authorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign")
|
||||
}
|
||||
signOpts, err := p.AuthorizeSSHSign(ctx, token)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign")
|
||||
}
|
||||
return signOpts, nil
|
||||
}
|
||||
|
||||
// authorizeSSHRenew authorizes an SSH certificate renewal request, by
|
||||
// validating the contents of an SSHPOP token.
|
||||
func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew")
|
||||
}
|
||||
cert, err := p.AuthorizeSSHRenew(ctx, token)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew")
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// authorizeSSHRekey authorizes an SSH certificate rekey request, by
|
||||
// validating the contents of an SSHPOP token.
|
||||
func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) {
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey")
|
||||
}
|
||||
cert, signOpts, err := p.AuthorizeSSHRekey(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey")
|
||||
}
|
||||
return cert, signOpts, nil
|
||||
}
|
||||
|
||||
// authorizeSSHRevoke authorizes an SSH certificate revoke request, by
|
||||
// validating the contents of an SSHPOP token.
|
||||
func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke")
|
||||
}
|
||||
if err = p.AuthorizeSSHRevoke(ctx, token); err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,7 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -9,7 +10,6 @@ import (
|
|||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
stepJOSE "github.com/smallstep/cli/jose"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
|
@ -255,28 +255,6 @@ func TestAuthConfigValidate(t *testing.T) {
|
|||
err: errors.New("authority cannot be undefined"),
|
||||
}
|
||||
},
|
||||
"fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest {
|
||||
return AuthConfigValidateTest{
|
||||
ac: &AuthConfig{
|
||||
Provisioners: provisioner.List{
|
||||
&provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
||||
&provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}},
|
||||
},
|
||||
},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-invalid-claims": func(t *testing.T) AuthConfigValidateTest {
|
||||
return AuthConfigValidateTest{
|
||||
ac: &AuthConfig{
|
||||
Provisioners: p,
|
||||
Claims: &provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: -1},
|
||||
},
|
||||
},
|
||||
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
|
||||
}
|
||||
},
|
||||
"ok-empty-provisioners": func(t *testing.T) AuthConfigValidateTest {
|
||||
return AuthConfigValidateTest{
|
||||
ac: &AuthConfig{},
|
||||
|
@ -311,7 +289,7 @@ func TestAuthConfigValidate(t *testing.T) {
|
|||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
if assert.Nil(t, tc.err, fmt.Sprintf("expected error: %s, but got <nil>", tc.err)) {
|
||||
assert.Equals(t, *tc.ac.Template, tc.asn1dn)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,96 +0,0 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/smallstep/certificates/db"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type MockAuthDB struct {
|
||||
err error
|
||||
ret1 interface{}
|
||||
isRevoked func(string) (bool, error)
|
||||
isSSHRevoked func(string) (bool, error)
|
||||
revoke func(rci *db.RevokedCertificateInfo) error
|
||||
revokeSSH func(rci *db.RevokedCertificateInfo) error
|
||||
storeCertificate func(crt *x509.Certificate) error
|
||||
useToken func(id, tok string) (bool, error)
|
||||
isSSHHost func(principal string) (bool, error)
|
||||
storeSSHCertificate func(crt *ssh.Certificate) error
|
||||
getSSHHostPrincipals func() ([]string, error)
|
||||
shutdown func() error
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) IsRevoked(sn string) (bool, error) {
|
||||
if m.isRevoked != nil {
|
||||
return m.isRevoked(sn)
|
||||
}
|
||||
return m.ret1.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) {
|
||||
if m.isSSHRevoked != nil {
|
||||
return m.isSSHRevoked(sn)
|
||||
}
|
||||
return m.ret1.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) UseToken(id, tok string) (bool, error) {
|
||||
if m.useToken != nil {
|
||||
return m.useToken(id, tok)
|
||||
}
|
||||
if m.ret1 == nil {
|
||||
return false, m.err
|
||||
}
|
||||
return m.ret1.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) Revoke(rci *db.RevokedCertificateInfo) error {
|
||||
if m.revoke != nil {
|
||||
return m.revoke(rci)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) RevokeSSH(rci *db.RevokedCertificateInfo) error {
|
||||
if m.revokeSSH != nil {
|
||||
return m.revokeSSH(rci)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
|
||||
if m.storeCertificate != nil {
|
||||
return m.storeCertificate(crt)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) {
|
||||
if m.isSSHHost != nil {
|
||||
return m.isSSHHost(principal)
|
||||
}
|
||||
return m.ret1.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
||||
if m.storeSSHCertificate != nil {
|
||||
return m.storeSSHCertificate(crt)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) {
|
||||
if m.getSSHHostPrincipals != nil {
|
||||
return m.getSSHHostPrincipals()
|
||||
}
|
||||
return m.ret1.([]string), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthDB) Shutdown() error {
|
||||
if m.shutdown != nil {
|
||||
return m.shutdown()
|
||||
}
|
||||
return m.err
|
||||
}
|
|
@ -1,67 +0,0 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type apiCtx map[string]interface{}
|
||||
|
||||
// Error implements the api.Error interface and adds context to error messages.
|
||||
type apiError struct {
|
||||
err error
|
||||
code int
|
||||
context apiCtx
|
||||
}
|
||||
|
||||
// Cause implements the errors.Causer interface and returns the original error.
|
||||
func (e *apiError) Cause() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// Error returns an error message with additional context.
|
||||
func (e *apiError) Error() string {
|
||||
ret := e.err.Error()
|
||||
|
||||
/*
|
||||
if len(e.context) > 0 {
|
||||
ret += "\n\nContext:"
|
||||
for k, v := range e.context {
|
||||
ret += fmt.Sprintf("\n %s: %v", k, v)
|
||||
}
|
||||
}
|
||||
*/
|
||||
return ret
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error in JSON format.
|
||||
type ErrorResponse struct {
|
||||
Status int `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// StatusCode returns an http status code indicating the type and severity of
|
||||
// the error.
|
||||
func (e *apiError) StatusCode() int {
|
||||
if e.code == 0 {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
return e.code
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaller interface for the Error struct.
|
||||
func (e *apiError) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&ErrorResponse{Status: e.code, Message: http.StatusText(e.code)})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
|
||||
func (e *apiError) UnmarshalJSON(data []byte) error {
|
||||
var er ErrorResponse
|
||||
if err := json.Unmarshal(data, &er); err != nil {
|
||||
return err
|
||||
}
|
||||
e.code = er.Status
|
||||
e.err = fmt.Errorf(er.Message)
|
||||
return nil
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/x509"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// ACME is the acme provisioner type, an entity that can authorize the ACME
|
||||
|
@ -79,7 +80,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,11 +3,13 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
func TestACME_Getters(t *testing.T) {
|
||||
|
@ -88,86 +90,98 @@ func TestACME_Init(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeRevoke(t *testing.T) {
|
||||
func TestACME_AuthorizeRenew(t *testing.T) {
|
||||
type test struct {
|
||||
p *ACME
|
||||
cert *x509.Certificate
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/renew-disabled": func(t *testing.T) test {
|
||||
p, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
assert.Nil(t, p.AuthorizeRevoke(context.TODO(), ""))
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
p.Claims = &Claims{DisableRenewal: &disable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
return test{
|
||||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()),
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *ACME
|
||||
args args
|
||||
err error
|
||||
}{
|
||||
{"ok", p1, args{nil}, nil},
|
||||
{"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())},
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tt.err)
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeSign(t *testing.T) {
|
||||
p1, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *ACME
|
||||
method Method
|
||||
type test struct {
|
||||
p *ACME
|
||||
token string
|
||||
code int
|
||||
err error
|
||||
}{
|
||||
{"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")},
|
||||
{"ok", p1, SignMethod, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), tt.method)
|
||||
if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, got) {
|
||||
assert.Len(t, 4, got)
|
||||
|
||||
for _, o := range got {
|
||||
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
|
||||
assert.Len(t, 4, opts)
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeACME))
|
||||
assert.Equals(t, v.Name, tt.prov.GetName())
|
||||
assert.Equals(t, v.Name, tc.p.GetName())
|
||||
assert.Equals(t, v.CredentialID, "")
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
||||
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -271,7 +272,7 @@ func (p *AWS) Init(config Config) (err error) {
|
|||
func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
payload, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign")
|
||||
}
|
||||
|
||||
doc := payload.document
|
||||
|
@ -305,7 +306,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -349,41 +350,41 @@ func (p *AWS) readURL(url string) ([]byte, error) {
|
|||
func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token")
|
||||
}
|
||||
if len(jwt.Headers) == 0 {
|
||||
return nil, errors.New("error parsing token: header is missing")
|
||||
return nil, errs.InternalServer("aws.authorizeToken; error parsing token, header is missing")
|
||||
}
|
||||
|
||||
var unsafeClaims awsPayload
|
||||
if err := jwt.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshaling claims")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling claims")
|
||||
}
|
||||
|
||||
var payload awsPayload
|
||||
if err := jwt.Claims(unsafeClaims.Amazon.Signature, &payload); err != nil {
|
||||
return nil, errors.Wrap(err, "error verifying claims")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error verifying claims")
|
||||
}
|
||||
|
||||
// Validate identity document signature
|
||||
if err := p.checkSignature(payload.Amazon.Document, payload.Amazon.Signature); err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token signature")
|
||||
}
|
||||
|
||||
var doc awsInstanceIdentityDocument
|
||||
if err := json.Unmarshal(payload.Amazon.Document, &doc); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshaling identity document")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling aws identity document")
|
||||
}
|
||||
|
||||
switch {
|
||||
case doc.AccountID == "":
|
||||
return nil, errors.New("identity document accountId cannot be empty")
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document accountId cannot be empty")
|
||||
case doc.InstanceID == "":
|
||||
return nil, errors.New("identity document instanceId cannot be empty")
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document instanceId cannot be empty")
|
||||
case doc.PrivateIP == "":
|
||||
return nil, errors.New("identity document privateIp cannot be empty")
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document privateIp cannot be empty")
|
||||
case doc.Region == "":
|
||||
return nil, errors.New("identity document region cannot be empty")
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document region cannot be empty")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -393,12 +394,12 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
Issuer: awsIssuer,
|
||||
Time: now,
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token")
|
||||
return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token")
|
||||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(payload.Audience, p.audiences.Sign) {
|
||||
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
// Validate subject, it has to be known if disableCustomSANs is enabled
|
||||
|
@ -406,7 +407,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
if payload.Subject != doc.InstanceID &&
|
||||
payload.Subject != doc.PrivateIP &&
|
||||
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
|
||||
return nil, errors.New("invalid token: invalid subject claim (sub)")
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -420,14 +421,14 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("invalid identity document: accountId is not valid")
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid aws identity document - accountId is not valid")
|
||||
}
|
||||
}
|
||||
|
||||
// validate instance age
|
||||
if d := p.InstanceAge.Value(); d > 0 {
|
||||
if now.Sub(doc.PendingTime) > d {
|
||||
return nil, errors.New("identity document pendingTime is too old")
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document pendingTime is too old")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -438,18 +439,18 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())
|
||||
}
|
||||
claims, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSSHSign")
|
||||
}
|
||||
|
||||
doc := claims.document
|
||||
|
||||
signOptions := []SignOption{
|
||||
// set the key id to the token subject
|
||||
sshCertificateKeyIDModifier(claims.Subject),
|
||||
sshCertKeyIDModifier(claims.Subject),
|
||||
}
|
||||
|
||||
// Default to host + known IPs/hostnames
|
||||
|
@ -461,9 +462,9 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
|
||||
|
||||
return append(signOptions,
|
||||
// Set the default extensions.
|
||||
|
@ -473,8 +474,8 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -10,12 +10,15 @@ import (
|
|||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -229,6 +232,213 @@ func TestAWS_Init(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAWS_authorizeToken(t *testing.T) {
|
||||
block, _ := pem.Decode([]byte(awsTestKey))
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
t.Fatal("error decoding AWS key")
|
||||
}
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
assert.FatalError(t, err)
|
||||
badKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type test struct {
|
||||
p *AWS
|
||||
token string
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; error parsing aws token"),
|
||||
}
|
||||
},
|
||||
"fail/cannot-validate-sig": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now(), badKey)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; invalid aws token signature"),
|
||||
}
|
||||
},
|
||||
"fail/empty-account-id": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, p.GetID(), "", "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; aws identity document accountId cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail/empty-instance-id": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "",
|
||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail/empty-private-ip": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
||||
"", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail/empty-region": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
||||
"127.0.0.1", "", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; aws identity document region cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-token-issuer": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", "bad-issuer", p.GetID(), p.Accounts[0], "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; invalid aws token"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-audience": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, "bad-audience", p.Accounts[0], "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-subject-disabled-custom-SANs": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
p.DisableCustomSANs = true
|
||||
tok, err := generateAWSToken(
|
||||
"foo", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-account-id": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, p.GetID(), "foo", "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid"),
|
||||
}
|
||||
},
|
||||
"fail/instance-age": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
p.InstanceAge = Duration{1 * time.Minute}
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("aws.authorizeToken; aws identity document pendingTime is too old"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAWSToken(
|
||||
"instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id",
|
||||
"127.0.0.1", "us-west-1", time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) && assert.NotNil(t, claims) {
|
||||
assert.Equals(t, claims.Subject, "instance-id")
|
||||
assert.Equals(t, claims.Issuer, awsIssuer)
|
||||
assert.NotNil(t, claims.Amazon)
|
||||
|
||||
aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID())
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, claims.Audience[0], aud)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAWS_AuthorizeSign(t *testing.T) {
|
||||
p1, srv, err := generateAWSWithServer()
|
||||
assert.FatalError(t, err)
|
||||
|
@ -326,26 +536,27 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
aws *AWS
|
||||
args args
|
||||
wantLen int
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, 5, false},
|
||||
{"ok", p2, args{t2}, 7, false},
|
||||
{"ok", p2, args{t2Hostname}, 7, false},
|
||||
{"ok", p2, args{t2PrivateIP}, 7, false},
|
||||
{"ok", p1, args{t4}, 5, false},
|
||||
{"fail account", p3, args{t3}, 0, true},
|
||||
{"fail token", p1, args{"token"}, 0, true},
|
||||
{"fail subject", p1, args{failSubject}, 0, true},
|
||||
{"fail issuer", p1, args{failIssuer}, 0, true},
|
||||
{"fail audience", p1, args{failAudience}, 0, true},
|
||||
{"fail account", p1, args{failAccount}, 0, true},
|
||||
{"fail instanceID", p1, args{failInstanceID}, 0, true},
|
||||
{"fail privateIP", p1, args{failPrivateIP}, 0, true},
|
||||
{"fail region", p1, args{failRegion}, 0, true},
|
||||
{"fail exp", p1, args{failExp}, 0, true},
|
||||
{"fail nbf", p1, args{failNbf}, 0, true},
|
||||
{"fail key", p1, args{failKey}, 0, true},
|
||||
{"fail instance age", p2, args{failInstanceAge}, 0, true},
|
||||
{"ok", p1, args{t1}, 5, http.StatusOK, false},
|
||||
{"ok", p2, args{t2}, 7, http.StatusOK, false},
|
||||
{"ok", p2, args{t2Hostname}, 7, http.StatusOK, false},
|
||||
{"ok", p2, args{t2PrivateIP}, 7, http.StatusOK, false},
|
||||
{"ok", p1, args{t4}, 5, http.StatusOK, false},
|
||||
{"fail account", p3, args{t3}, 0, http.StatusUnauthorized, true},
|
||||
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
|
||||
{"fail subject", p1, args{failSubject}, 0, http.StatusUnauthorized, true},
|
||||
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
|
||||
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
|
||||
{"fail account", p1, args{failAccount}, 0, http.StatusUnauthorized, true},
|
||||
{"fail instanceID", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true},
|
||||
{"fail privateIP", p1, args{failPrivateIP}, 0, http.StatusUnauthorized, true},
|
||||
{"fail region", p1, args{failRegion}, 0, http.StatusUnauthorized, true},
|
||||
{"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
|
||||
{"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
|
||||
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
|
||||
{"fail instance age", p2, args{failInstanceAge}, 0, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -354,8 +565,13 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
} else {
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -368,6 +584,14 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
p2, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
// disable sshCA
|
||||
disable := false
|
||||
p2.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
@ -407,30 +631,35 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
|||
aws *AWS
|
||||
args args
|
||||
expected *SSHOptions
|
||||
code int
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, false, false},
|
||||
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, false, true},
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, http.StatusOK, false, false},
|
||||
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, http.StatusOK, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
|
||||
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.aws.AuthorizeSSHSign(ctx, tt.args.token)
|
||||
got, err := tt.aws.AuthorizeSSHSign(context.Background(), tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AWS.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
|
@ -447,6 +676,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAWS_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
|
@ -466,44 +696,20 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
|
|||
name string
|
||||
aws *AWS
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.aws.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAWS_AuthorizeRevoke(t *testing.T) {
|
||||
p1, srv, err := generateAWSWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
aws *AWS
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, true}, // revoke is disabled
|
||||
{"fail", p1, args{"token"}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.aws.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AWS.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -209,14 +210,14 @@ func (p *Azure) Init(config Config) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// parseToken returns the claims, name, group, error.
|
||||
func (p *Azure) parseToken(token string) (*azurePayload, string, string, error) {
|
||||
// authorizeToken returns the claims, name, group, error.
|
||||
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, "", "", errors.Wrapf(err, "error parsing token")
|
||||
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
|
||||
}
|
||||
if len(jwt.Headers) == 0 {
|
||||
return nil, "", "", errors.New("error parsing token: header is missing")
|
||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
|
||||
}
|
||||
|
||||
var found bool
|
||||
|
@ -229,7 +230,7 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, "", "", errors.New("cannot validate token")
|
||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
|
||||
}
|
||||
|
||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||
|
@ -237,17 +238,17 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
|
|||
Issuer: p.oidcConfig.Issuer,
|
||||
Time: time.Now(),
|
||||
}, 1*time.Minute); err != nil {
|
||||
return nil, "", "", errors.Wrap(err, "failed to validate payload")
|
||||
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload")
|
||||
}
|
||||
|
||||
// Validate TenantID
|
||||
if claims.TenantID != p.TenantID {
|
||||
return nil, "", "", errors.New("validation failed: invalid tenant id claim (tid)")
|
||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
|
||||
}
|
||||
|
||||
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
|
||||
if len(re) != 4 {
|
||||
return nil, "", "", errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID)
|
||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
|
||||
}
|
||||
group, name := re[2], re[3]
|
||||
return &claims, name, group, nil
|
||||
|
@ -256,9 +257,9 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error)
|
|||
// AuthorizeSign validates the given token and returns the sign options that
|
||||
// will be used on certificate creation.
|
||||
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
_, name, group, err := p.parseToken(token)
|
||||
_, name, group, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
|
||||
}
|
||||
|
||||
// Filter by resource group
|
||||
|
@ -271,7 +272,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("validation failed: invalid resource group")
|
||||
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid resource group")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -301,7 +302,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -309,16 +310,16 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
|
||||
_, name, _, err := p.parseToken(token)
|
||||
_, name, _, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign")
|
||||
}
|
||||
signOptions := []SignOption{
|
||||
// set the key id to the token subject
|
||||
sshCertificateKeyIDModifier(name),
|
||||
sshCertKeyIDModifier(name),
|
||||
}
|
||||
|
||||
// Default to host + known hostnames
|
||||
|
@ -327,9 +328,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
Principals: []string{name},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
|
||||
|
||||
return append(signOptions,
|
||||
// Set the default extensions.
|
||||
|
@ -339,9 +340,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,10 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func TestAzure_Getters(t *testing.T) {
|
||||
|
@ -209,6 +212,148 @@ func TestAzure_Init(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAzure_authorizeToken(t *testing.T) {
|
||||
type test struct {
|
||||
p *Azure
|
||||
token string
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
p, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("azure.authorizeToken; error parsing azure token"),
|
||||
}
|
||||
},
|
||||
"fail/cannot-validate-sig": func(t *testing.T) test {
|
||||
p, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
time.Now(), jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("azure.authorizeToken; cannot validate azure token"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-token-issuer": func(t *testing.T) test {
|
||||
p, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("azure.authorizeToken; failed to validate azure token payload"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-tenant-id": func(t *testing.T) test {
|
||||
p, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||
"foo", "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-xms-mir-id": func(t *testing.T) test {
|
||||
p, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
jwk := &p.keyStore.keySet.Keys[0]
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||
)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
now := time.Now()
|
||||
claims := azurePayload{
|
||||
Claims: jose.Claims{
|
||||
Subject: "subject",
|
||||
Issuer: p.oidcConfig.Issuer,
|
||||
IssuedAt: jose.NewNumericDate(now),
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
Audience: []string{azureDefaultAudience},
|
||||
ID: "the-jti",
|
||||
},
|
||||
AppID: "the-appid",
|
||||
AppIDAcr: "the-appidacr",
|
||||
IdentityProvider: "the-idp",
|
||||
ObjectID: "the-oid",
|
||||
TenantID: p.TenantID,
|
||||
Version: "the-version",
|
||||
XMSMirID: "foo",
|
||||
}
|
||||
tok, err := jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("azure.authorizeToken; error parsing xms_mirid claim - foo"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if claims, name, group, err := tc.p.authorizeToken(tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, claims.Subject, "subject")
|
||||
assert.Equals(t, claims.Issuer, tc.p.oidcConfig.Issuer)
|
||||
assert.Equals(t, claims.Audience[0], azureDefaultAudience)
|
||||
|
||||
assert.Equals(t, name, "virtualMachine")
|
||||
assert.Equals(t, group, "resourceGroup")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeSign(t *testing.T) {
|
||||
p1, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
|
@ -283,19 +428,20 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
azure *Azure
|
||||
args args
|
||||
wantLen int
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, 4, false},
|
||||
{"ok", p2, args{t2}, 6, false},
|
||||
{"ok", p1, args{t11}, 4, false},
|
||||
{"fail tenant", p3, args{t3}, 0, true},
|
||||
{"fail resource group", p4, args{t4}, 0, true},
|
||||
{"fail token", p1, args{"token"}, 0, true},
|
||||
{"fail issuer", p1, args{failIssuer}, 0, true},
|
||||
{"fail audience", p1, args{failAudience}, 0, true},
|
||||
{"fail exp", p1, args{failExp}, 0, true},
|
||||
{"fail nbf", p1, args{failNbf}, 0, true},
|
||||
{"fail key", p1, args{failKey}, 0, true},
|
||||
{"ok", p1, args{t1}, 4, http.StatusOK, false},
|
||||
{"ok", p2, args{t2}, 6, http.StatusOK, false},
|
||||
{"ok", p1, args{t11}, 4, http.StatusOK, false},
|
||||
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
|
||||
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
|
||||
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
|
||||
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
|
||||
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
|
||||
{"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
|
||||
{"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
|
||||
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -304,8 +450,51 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
} else {
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
azure *Azure
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -318,6 +507,14 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
p2, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
// disable sshCA
|
||||
disable := false
|
||||
p2.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
@ -349,28 +546,33 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
|
|||
azure *Azure
|
||||
args args
|
||||
expected *SSHOptions
|
||||
code int
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, false, true},
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
|
||||
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.azure.AuthorizeSSHSign(ctx, tt.args.token)
|
||||
got, err := tt.azure.AuthorizeSSHSign(context.Background(), tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
|
@ -388,68 +590,6 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
azure *Azure
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeRevoke(t *testing.T) {
|
||||
az, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
token, err := az.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
azure *Azure
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok token", az, args{token}, true}, // revoke is disabled
|
||||
{"bad token", az, args{"bad token"}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.azure.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_assertConfig(t *testing.T) {
|
||||
p1, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
|
|
|
@ -78,7 +78,7 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
|
|||
|
||||
// match with server audiences
|
||||
if matchesAudience(claims.Audience, audiences) {
|
||||
// Use fragment to get provisioner name (GCP, AWS)
|
||||
// Use fragment to get provisioner name (GCP, AWS, SSHPOP)
|
||||
if fragment != "" {
|
||||
return c.Load(fragment)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -210,7 +211,7 @@ func (p *GCP) Init(config Config) error {
|
|||
func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign")
|
||||
}
|
||||
|
||||
ce := claims.Google.ComputeEngine
|
||||
|
@ -239,10 +240,10 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
), nil
|
||||
}
|
||||
|
||||
// AuthorizeRenewal returns an error if the renewal is disabled.
|
||||
func (p *GCP) AuthorizeRenewal(ctx context.Context, cert *x509.Certificate) error {
|
||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -260,10 +261,10 @@ func (p *GCP) assertConfig() {
|
|||
func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token")
|
||||
}
|
||||
if len(jwt.Headers) == 0 {
|
||||
return nil, errors.New("error parsing token: header is missing")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; error parsing gcp token - header is missing")
|
||||
}
|
||||
|
||||
var found bool
|
||||
|
@ -277,7 +278,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.Errorf("failed to validate payload: cannot find key for kid %s", kid)
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -287,12 +288,12 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
Issuer: "https://accounts.google.com",
|
||||
Time: now,
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; invalid gcp token payload")
|
||||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, p.audiences.Sign) {
|
||||
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
// validate subject (service account)
|
||||
|
@ -305,7 +306,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("invalid token: invalid subject claim")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid subject claim")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -319,26 +320,26 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("invalid token: invalid project id")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid project id")
|
||||
}
|
||||
}
|
||||
|
||||
// validate instance age
|
||||
if d := p.InstanceAge.Value(); d > 0 {
|
||||
if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d {
|
||||
return nil, errors.New("token google.compute_engine.instance_creation_timestamp is too old")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case claims.Google.ComputeEngine.InstanceID == "":
|
||||
return nil, errors.New("token google.compute_engine.instance_id cannot be empty")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")
|
||||
case claims.Google.ComputeEngine.InstanceName == "":
|
||||
return nil, errors.New("token google.compute_engine.instance_name cannot be empty")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")
|
||||
case claims.Google.ComputeEngine.ProjectID == "":
|
||||
return nil, errors.New("token google.compute_engine.project_id cannot be empty")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")
|
||||
case claims.Google.ComputeEngine.Zone == "":
|
||||
return nil, errors.New("token google.compute_engine.zone cannot be empty")
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
|
@ -347,18 +348,18 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())
|
||||
}
|
||||
claims, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSSHSign")
|
||||
}
|
||||
|
||||
ce := claims.Google.ComputeEngine
|
||||
|
||||
signOptions := []SignOption{
|
||||
// set the key id to the token subject
|
||||
sshCertificateKeyIDModifier(ce.InstanceName),
|
||||
sshCertKeyIDModifier(ce.InstanceName),
|
||||
}
|
||||
|
||||
// Default to host + known hostnames
|
||||
|
@ -370,9 +371,9 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
|
||||
|
||||
return append(signOptions,
|
||||
// Set the default extensions
|
||||
|
@ -382,8 +383,8 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -16,7 +16,10 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func TestGCP_Getters(t *testing.T) {
|
||||
|
@ -211,6 +214,202 @@ func TestGCP_Init(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGCP_authorizeToken(t *testing.T) {
|
||||
type test struct {
|
||||
p *GCP
|
||||
token string
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; error parsing gcp token"),
|
||||
}
|
||||
},
|
||||
"fail/cannot-validate-sig": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
time.Now(), jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid "),
|
||||
}
|
||||
},
|
||||
"fail/invalid-issuer": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://foo.bar.zap", p.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; invalid gcp token payload"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-serviceAccount": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateGCPToken("foo",
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-projectID": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
p.ProjectIDs = []string{"foo", "bar"}
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; invalid gcp token - invalid project id"),
|
||||
}
|
||||
},
|
||||
"fail/instance-age": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
p.InstanceAge = Duration{1 * time.Minute}
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
time.Now().Add(-1*time.Minute), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old"),
|
||||
}
|
||||
},
|
||||
"fail/empty-instance-id": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"", "instance-name", "project-id", "zone",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail/empty-instance-name": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"instance-id", "", "project-id", "zone",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail/empty-project-id": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"instance-id", "instance-name", "", "zone",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail/empty-zone": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateGCPToken(p.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) && assert.NotNil(t, claims) {
|
||||
assert.Equals(t, claims.Subject, tc.p.ServiceAccounts[0])
|
||||
assert.Equals(t, claims.Issuer, "https://accounts.google.com")
|
||||
assert.NotNil(t, claims.Google)
|
||||
|
||||
aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID())
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, claims.Audience[0], aud)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGCP_AuthorizeSign(t *testing.T) {
|
||||
p1, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
|
@ -313,24 +512,25 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
gcp *GCP
|
||||
args args
|
||||
wantLen int
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, 4, false},
|
||||
{"ok", p2, args{t2}, 6, false},
|
||||
{"ok", p3, args{t3}, 4, false},
|
||||
{"fail token", p1, args{"token"}, 0, true},
|
||||
{"fail key", p1, args{failKey}, 0, true},
|
||||
{"fail iss", p1, args{failIss}, 0, true},
|
||||
{"fail aud", p1, args{failAud}, 0, true},
|
||||
{"fail exp", p1, args{failExp}, 0, true},
|
||||
{"fail nbf", p1, args{failNbf}, 0, true},
|
||||
{"fail service account", p1, args{failServiceAccount}, 0, true},
|
||||
{"fail invalid project id", p3, args{failInvalidProjectID}, 0, true},
|
||||
{"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, true},
|
||||
{"fail instance id", p1, args{failInstanceID}, 0, true},
|
||||
{"fail instance name", p1, args{failInstanceName}, 0, true},
|
||||
{"fail project id", p1, args{failProjectID}, 0, true},
|
||||
{"fail zone", p1, args{failZone}, 0, true},
|
||||
{"ok", p1, args{t1}, 4, http.StatusOK, false},
|
||||
{"ok", p2, args{t2}, 6, http.StatusOK, false},
|
||||
{"ok", p3, args{t3}, 4, http.StatusOK, false},
|
||||
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
|
||||
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
|
||||
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
|
||||
{"fail aud", p1, args{failAud}, 0, http.StatusUnauthorized, true},
|
||||
{"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
|
||||
{"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
|
||||
{"fail service account", p1, args{failServiceAccount}, 0, http.StatusUnauthorized, true},
|
||||
{"fail invalid project id", p3, args{failInvalidProjectID}, 0, http.StatusUnauthorized, true},
|
||||
{"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, http.StatusUnauthorized, true},
|
||||
{"fail instance id", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true},
|
||||
{"fail instance name", p1, args{failInstanceName}, 0, http.StatusUnauthorized, true},
|
||||
{"fail project id", p1, args{failProjectID}, 0, http.StatusUnauthorized, true},
|
||||
{"fail zone", p1, args{failZone}, 0, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -339,8 +539,13 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
} else {
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -352,6 +557,14 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
|||
p1, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p2, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
// disable sshCA
|
||||
disable := false
|
||||
p2.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := generateGCPToken(p1.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p1.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
|
@ -394,30 +607,35 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
|||
gcp *GCP
|
||||
args args
|
||||
expected *SSHOptions
|
||||
code int
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, false, false},
|
||||
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, false, true},
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, http.StatusOK, false, false},
|
||||
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, http.StatusOK, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
|
||||
{"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.gcp.AuthorizeSSHSign(ctx, tt.args.token)
|
||||
got, err := tt.gcp.AuthorizeSSHSign(context.Background(), tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
|
@ -435,7 +653,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGCP_AuthorizeRenewal(t *testing.T) {
|
||||
func TestGCP_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateGCP()
|
||||
|
@ -454,46 +672,20 @@ func TestGCP_AuthorizeRenewal(t *testing.T) {
|
|||
name string
|
||||
prov *GCP
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenewal(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGCP_AuthorizeRevoke(t *testing.T) {
|
||||
p1, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := generateGCPToken(p1.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p1.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
gcp *GCP
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, true}, // revoke is disabled
|
||||
{"fail", p1, args{"token"}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.gcp.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,9 +3,11 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
@ -99,12 +101,12 @@ func (p *JWK) Init(config Config) (err error) {
|
|||
func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk token")
|
||||
}
|
||||
|
||||
var claims jwtPayload
|
||||
if err = jwt.Claims(p.Key, &claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk claims")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -113,17 +115,17 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
|
|||
Issuer: p.Name,
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token")
|
||||
return nil, errs.Wrapf(http.StatusUnauthorized, err, "jwk.authorizeToken; invalid jwk claims")
|
||||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
return nil, errors.Errorf("invalid token: invalid audience claim (aud); want %s, but got %s",
|
||||
return nil, errs.Unauthorized("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s",
|
||||
audiences, claims.Audience)
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errors.New("token subject cannot be empty")
|
||||
return nil, errs.Unauthorized("jwk.authorizeToken; jwk token subject cannot be empty")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
|
@ -133,14 +135,14 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
|
|||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
||||
return err
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
|
||||
}
|
||||
|
||||
// NOTE: This is for backwards compatibility with older versions of cli
|
||||
|
@ -171,7 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -179,20 +181,20 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())
|
||||
}
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
|
||||
}
|
||||
if claims.Step == nil || claims.Step.SSH == nil {
|
||||
return nil, errors.New("authorization token must be an SSH provisioning token")
|
||||
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")
|
||||
}
|
||||
|
||||
opts := claims.Step.SSH
|
||||
signOptions := []SignOption{
|
||||
// validates user's SSHOptions with the ones in the token
|
||||
sshCertificateOptionsValidator(*opts),
|
||||
sshCertOptionsValidator(*opts),
|
||||
}
|
||||
|
||||
t := now()
|
||||
|
@ -205,19 +207,19 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals))
|
||||
}
|
||||
if !opts.ValidAfter.IsZero() {
|
||||
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
|
||||
signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
|
||||
}
|
||||
if !opts.ValidBefore.IsZero() {
|
||||
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
|
||||
signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
|
||||
}
|
||||
if opts.KeyID != "" {
|
||||
signOptions = append(signOptions, sshCertificateKeyIDModifier(opts.KeyID))
|
||||
signOptions = append(signOptions, sshCertKeyIDModifier(opts.KeyID))
|
||||
} else {
|
||||
signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject))
|
||||
signOptions = append(signOptions, sshCertKeyIDModifier(claims.Subject))
|
||||
}
|
||||
|
||||
// Default to a user certificate with no principals if not set
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier{CertType: SSHUserCert})
|
||||
|
||||
return append(signOptions,
|
||||
// Set the default extensions.
|
||||
|
@ -229,14 +231,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
|
||||
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.audiences.SSHRevoke)
|
||||
return err
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
|
||||
}
|
||||
|
|
|
@ -7,12 +7,14 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -162,25 +164,29 @@ func TestJWK_authorizeToken(t *testing.T) {
|
|||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
code int
|
||||
err error
|
||||
}{
|
||||
{"fail-token", p1, args{failTok}, errors.New("error parsing token")},
|
||||
{"fail-key", p1, args{failKey}, errors.New("error parsing claims")},
|
||||
{"fail-claims", p1, args{failClaims}, errors.New("error parsing claims")},
|
||||
{"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
|
||||
{"fail-issuer", p1, args{failIss}, errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")},
|
||||
{"fail-expired", p1, args{failExp}, errors.New("invalid token: square/go-jose/jwt: validation failed, token is expired (exp)")},
|
||||
{"fail-not-before", p1, args{failNbf}, errors.New("invalid token: square/go-jose/jwt: validation failed, token not valid yet (nbf)")},
|
||||
{"fail-audience", p1, args{failAud}, errors.New("invalid token: invalid audience claim (aud)")},
|
||||
{"fail-subject", p1, args{failSub}, errors.New("token subject cannot be empty")},
|
||||
{"ok", p1, args{t1}, nil},
|
||||
{"ok-no-encrypted-key", p2, args{t2}, nil},
|
||||
{"ok-no-sans", p1, args{t3}, nil},
|
||||
{"fail-token", p1, args{failTok}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk token")},
|
||||
{"fail-key", p1, args{failKey}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")},
|
||||
{"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")},
|
||||
{"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
|
||||
{"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")},
|
||||
{"fail-expired", p1, args{failExp}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, token is expired (exp)")},
|
||||
{"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, token not valid yet (nbf)")},
|
||||
{"fail-audience", p1, args{failAud}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk token audience claim (aud)")},
|
||||
{"fail-subject", p1, args{failSub}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; jwk token subject cannot be empty")},
|
||||
{"ok", p1, args{t1}, http.StatusOK, nil},
|
||||
{"ok-no-encrypted-key", p2, args{t2}, http.StatusOK, nil},
|
||||
{"ok-no-sans", p1, args{t3}, http.StatusOK, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
|
@ -208,15 +214,19 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
|
|||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
code int
|
||||
err error
|
||||
}{
|
||||
{"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
|
||||
{"ok", p1, args{t1}, nil},
|
||||
{"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.AuthorizeRevoke: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
|
||||
{"ok", p1, args{t1}, http.StatusOK, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token); err != nil {
|
||||
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
}
|
||||
|
@ -246,20 +256,24 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
|||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
code int
|
||||
err error
|
||||
dns []string
|
||||
emails []string
|
||||
ips []net.IP
|
||||
}{
|
||||
{name: "fail-signature", prov: p1, args: args{failSig}, err: errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")},
|
||||
{"ok-sans", p1, args{t1}, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
{"ok-no-sans", p1, args{t2}, nil, []string{"subject"}, []string{}, []net.IP{}},
|
||||
{name: "fail-signature", prov: p1, args: args{failSig}, code: http.StatusUnauthorized, err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
|
||||
{"ok-sans", p1, args{t1}, http.StatusOK, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
{"ok-no-sans", p1, args{t2}, http.StatusOK, nil, []string{"subject"}, []string{}, []net.IP{}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
|
@ -315,15 +329,20 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
|
|||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -335,6 +354,14 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
|
|||
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
// disable sshCA
|
||||
disable := false
|
||||
p2.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
@ -382,30 +409,34 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
|
|||
prov *JWK
|
||||
args args
|
||||
expected *SSHOptions
|
||||
code int
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
|
||||
{"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
|
||||
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
|
||||
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
|
||||
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
|
||||
{"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, false, false},
|
||||
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
|
||||
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
|
||||
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, true, false},
|
||||
{"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
|
||||
{"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false},
|
||||
{"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false},
|
||||
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false},
|
||||
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false},
|
||||
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false},
|
||||
{"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
|
||||
{"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false},
|
||||
{"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
|
||||
{"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token)
|
||||
got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
|
@ -511,10 +542,9 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk)
|
||||
assert.FatalError(t, err)
|
||||
if got, err := tt.prov.AuthorizeSSHSign(ctx, token); (err != nil) != tt.wantErr {
|
||||
if got, err := tt.prov.AuthorizeSSHSign(context.Background(), token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if !tt.wantErr && assert.NotNil(t, got) {
|
||||
var opts SSHOptions
|
||||
|
@ -535,3 +565,52 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_AuthorizeSSHRevoke(t *testing.T) {
|
||||
type test struct {
|
||||
p *JWK
|
||||
token string
|
||||
code int
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/invalid-token": func(t *testing.T) test {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("jwk.AuthorizeSSHRevoke: jwk.authorizeToken; error parsing jwk token"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := decryptJSONWebKey(p.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
tok, err := generateToken("subject", p.Name, testAudiences.SSHRevoke[0], "name@smallstep.com", []string{"127.0.0.1", "max@smallstep.com", "foo"}, time.Now(), jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,8 +6,10 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
@ -138,7 +140,8 @@ func (p *K8sSA) Init(config Config) (err error) {
|
|||
func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err,
|
||||
"k8ssa.authorizeToken; error parsing k8sSA token")
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -146,7 +149,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
|||
claims k8sSAPayload
|
||||
)
|
||||
if p.pubKeys == nil {
|
||||
return nil, errors.New("TokenReview API integration not implemented")
|
||||
return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")
|
||||
/* NOTE: We plan to support the TokenReview API in a future release.
|
||||
Below is some code that should be useful when we prioritize
|
||||
this integration.
|
||||
|
@ -174,7 +177,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
|||
}
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("error validating token and extracting claims")
|
||||
return nil, errs.Unauthorized("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -182,11 +185,11 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
|||
if err = claims.Validate(jose.Expected{
|
||||
Issuer: k8sSAIssuer,
|
||||
}); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token claims")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "k8ssa.authorizeToken; invalid k8sSA token claims")
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errors.New("token subject cannot be empty")
|
||||
return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA token subject cannot be empty")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
|
@ -196,14 +199,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
|||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
||||
return err
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
_, err := p.authorizeToken(token, p.audiences.Sign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if _, err := p.authorizeToken(token, p.audiences.Sign); err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
|
||||
}
|
||||
|
||||
return []SignOption{
|
||||
|
@ -219,7 +221,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -227,17 +229,14 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
|
|||
// AuthorizeSSHSign validates an request for an SSH certificate.
|
||||
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errors.Errorf("authorizeSSHSign: ssh ca is disabled for provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())
|
||||
}
|
||||
_, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "authorizeSSHSign")
|
||||
if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
|
||||
}
|
||||
|
||||
// Default to a user certificate with no principals if not set
|
||||
signOptions := []SignOption{
|
||||
sshCertificateDefaultsModifier{CertType: SSHUserCert},
|
||||
}
|
||||
signOptions := []SignOption{sshCertDefaultsModifier{CertType: SSHUserCert}}
|
||||
|
||||
return append(signOptions,
|
||||
// Set the default extensions.
|
||||
|
@ -247,9 +246,9 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,11 +3,13 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -36,6 +38,7 @@ func TestK8sSA_authorizeToken(t *testing.T) {
|
|||
p *K8sSA
|
||||
token string
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
|
@ -44,7 +47,24 @@ func TestK8sSA_authorizeToken(t *testing.T) {
|
|||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
err: errors.New("error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("k8ssa.authorizeToken; error parsing k8sSA token"),
|
||||
}
|
||||
},
|
||||
"fail/not-implemented": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
p, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("", p.Name, testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk)
|
||||
p.pubKeys = nil
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
"fail/error-validating-token": func(t *testing.T) test {
|
||||
|
@ -58,7 +78,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
|
|||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("error validating token and extracting claims"),
|
||||
err: errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
"fail/invalid-issuer": func(t *testing.T) test {
|
||||
|
@ -73,7 +94,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
|
|||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("invalid token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("k8ssa.authorizeToken; invalid k8sSA token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -94,6 +116,9 @@ func TestK8sSA_authorizeToken(t *testing.T) {
|
|||
tc := tt(t)
|
||||
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
|
@ -105,12 +130,12 @@ func TestK8sSA_authorizeToken(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestK8sSA_AuthorizeSign(t *testing.T) {
|
||||
func TestK8sSA_AuthorizeRevoke(t *testing.T) {
|
||||
type test struct {
|
||||
p *K8sSA
|
||||
token string
|
||||
ctx context.Context
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/invalid-token": func(t *testing.T) test {
|
||||
|
@ -119,21 +144,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
|||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
err: errors.New("error parsing token"),
|
||||
}
|
||||
},
|
||||
"fail/ssh-unimplemented": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
p, err := generateK8sSA(jwk.Public().Key)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateK8sSAToken(jwk, nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
|
||||
token: tok,
|
||||
err: errors.Errorf("ssh certificates not enabled for k8s ServiceAccount provisioners"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("k8ssa.AuthorizeRevoke: k8ssa.authorizeToken; error parsing k8sSA token"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -145,7 +157,6 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
ctx: NewContextWithMethod(context.Background(), SignMethod),
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
|
@ -153,10 +164,110 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token); err != nil {
|
||||
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
||||
type test struct {
|
||||
p *K8sSA
|
||||
cert *x509.Certificate
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/renew-disabled": func(t *testing.T) test {
|
||||
p, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
// disable renewal
|
||||
disable := true
|
||||
p.Claims = &Claims{DisableRenewal: &disable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestK8sSA_AuthorizeSign(t *testing.T) {
|
||||
type test struct {
|
||||
p *K8sSA
|
||||
token string
|
||||
code int
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/invalid-token": func(t *testing.T) test {
|
||||
p, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("k8ssa.AuthorizeSign: k8ssa.authorizeToken; error parsing k8sSA token"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
p, err := generateK8sSA(jwk.Public().Key)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateK8sSAToken(jwk, nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
if assert.NotNil(t, opts) {
|
||||
|
@ -187,20 +298,37 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestK8sSA_AuthorizeRevoke(t *testing.T) {
|
||||
func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
||||
type test struct {
|
||||
p *K8sSA
|
||||
token string
|
||||
code int
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/sshCA-disabled": func(t *testing.T) test {
|
||||
p, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
// disable sshCA
|
||||
disable := false
|
||||
p.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()),
|
||||
}
|
||||
},
|
||||
"fail/invalid-token": func(t *testing.T) test {
|
||||
p, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
err: errors.New("error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("k8ssa.AuthorizeSSHSign: k8ssa.authorizeToken; error parsing k8sSA token"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -219,45 +347,36 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil {
|
||||
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *K8sSA
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("X5C.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if assert.Nil(t, tc.err) {
|
||||
if assert.NotNil(t, opts) {
|
||||
tot := 0
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case sshCertDefaultsModifier:
|
||||
assert.Equals(t, v.CertType, SSHUserCert)
|
||||
case *sshDefaultExtensionModifier:
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
case *sshDefaultPublicKeyValidator:
|
||||
case *sshCertDefaultValidator:
|
||||
case *sshDefaultDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
tot++
|
||||
}
|
||||
assert.Equals(t, tot, 6)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -16,14 +16,16 @@ const (
|
|||
SignMethod Method = iota
|
||||
// RevokeMethod is the method used to revoke X.509 certificates.
|
||||
RevokeMethod
|
||||
// SignSSHMethod is the method used to sign SSH certificates.
|
||||
SignSSHMethod
|
||||
// RenewSSHMethod is the method used to renew SSH certificates.
|
||||
RenewSSHMethod
|
||||
// RevokeSSHMethod is the method used to revoke SSH certificates.
|
||||
RevokeSSHMethod
|
||||
// RekeySSHMethod is the method used to rekey SSH certificates.
|
||||
RekeySSHMethod
|
||||
// RenewMethod is the method used to renew X.509 certificates.
|
||||
RenewMethod
|
||||
// SSHSignMethod is the method used to sign SSH certificates.
|
||||
SSHSignMethod
|
||||
// SSHRenewMethod is the method used to renew SSH certificates.
|
||||
SSHRenewMethod
|
||||
// SSHRevokeMethod is the method used to revoke SSH certificates.
|
||||
SSHRevokeMethod
|
||||
// SSHRekeyMethod is the method used to rekey SSH certificates.
|
||||
SSHRekeyMethod
|
||||
)
|
||||
|
||||
// String returns a string representation of the context method.
|
||||
|
@ -33,14 +35,16 @@ func (m Method) String() string {
|
|||
return "sign-method"
|
||||
case RevokeMethod:
|
||||
return "revoke-method"
|
||||
case SignSSHMethod:
|
||||
return "sign-ssh-method"
|
||||
case RenewSSHMethod:
|
||||
return "renew-ssh-method"
|
||||
case RevokeSSHMethod:
|
||||
return "revoke-ssh-method"
|
||||
case RekeySSHMethod:
|
||||
return "rekey-ssh-method"
|
||||
case RenewMethod:
|
||||
return "renew-method"
|
||||
case SSHSignMethod:
|
||||
return "ssh-sign-method"
|
||||
case SSHRenewMethod:
|
||||
return "ssh-renew-method"
|
||||
case SSHRevokeMethod:
|
||||
return "ssh-revoke-method"
|
||||
case SSHRekeyMethod:
|
||||
return "ssh-rekey-method"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@ func Test_noop(t *testing.T) {
|
|||
assert.Equals(t, "noop", p.GetName())
|
||||
assert.Equals(t, noopType, p.GetType())
|
||||
assert.Equals(t, nil, p.Init(Config{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRenew(context.TODO(), &x509.Certificate{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRevoke(context.TODO(), "foo"))
|
||||
assert.Equals(t, nil, p.AuthorizeRenew(context.Background(), &x509.Certificate{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRevoke(context.Background(), "foo"))
|
||||
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
assert.Equals(t, "", kid)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -189,17 +190,17 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
|||
Audience: jose.Audience{o.ClientID},
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return errors.Wrap(err, "failed to validate payload")
|
||||
return errs.Wrap(http.StatusUnauthorized, err, "validatePayload: failed to validate oidc token payload")
|
||||
}
|
||||
|
||||
// Validate azp if present
|
||||
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
|
||||
return errors.New("failed to validate payload: invalid azp")
|
||||
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: invalid azp")
|
||||
}
|
||||
|
||||
// Enforce an email claim
|
||||
if p.Email == "" {
|
||||
return errors.New("failed to validate payload: email not found")
|
||||
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email not found")
|
||||
}
|
||||
|
||||
// Validate domains (case-insensitive)
|
||||
|
@ -213,7 +214,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("failed to validate payload: email is not allowed")
|
||||
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email is not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -229,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("validation failed: invalid group")
|
||||
return errs.Unauthorized("validatePayload: oidc token payload validation failed: invalid group")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -241,13 +242,15 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
|||
func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err,
|
||||
"oidc.AuthorizeToken; error parsing oidc token")
|
||||
}
|
||||
|
||||
// Parse claims to get the kid
|
||||
var claims openIDPayload
|
||||
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err,
|
||||
"oidc.AuthorizeToken; error parsing oidc token claims")
|
||||
}
|
||||
|
||||
found := false
|
||||
|
@ -260,11 +263,11 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("cannot validate token")
|
||||
return nil, errs.Unauthorized("oidc.AuthorizeToken; cannot validate oidc token")
|
||||
}
|
||||
|
||||
if err := o.ValidatePayload(claims); err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeToken")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
|
@ -276,21 +279,21 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
|
|||
func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
claims, err := o.authorizeToken(token)
|
||||
if err != nil {
|
||||
return err
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeRevoke")
|
||||
}
|
||||
|
||||
// Only admins can revoke certificates.
|
||||
if o.IsAdmin(claims.Email) {
|
||||
return nil
|
||||
}
|
||||
return errors.New("cannot revoke with non-admin token")
|
||||
return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := o.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSign")
|
||||
}
|
||||
|
||||
so := []SignOption{
|
||||
|
@ -315,7 +318,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// certificate was configured to allow renewals.
|
||||
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if o.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", o.GetID())
|
||||
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -323,22 +326,22 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !o.claimer.IsSSHCAEnabled() {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", o.GetID())
|
||||
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())
|
||||
}
|
||||
claims, err := o.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
|
||||
}
|
||||
signOptions := []SignOption{
|
||||
// set the key id to the token email
|
||||
sshCertificateKeyIDModifier(claims.Email),
|
||||
sshCertKeyIDModifier(claims.Email),
|
||||
}
|
||||
|
||||
// Get the identity using either the default identityFunc or one injected
|
||||
// externally.
|
||||
iden, err := o.getIdentityFunc(o, claims.Email)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "authorizeSSHSign")
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
|
||||
}
|
||||
defaults := SSHOptions{
|
||||
CertType: SSHUserCert,
|
||||
|
@ -349,12 +352,12 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
|||
// Non-admin users can only use principals returned by the identityFunc, and
|
||||
// can only sign user certificates.
|
||||
if !o.IsAdmin(claims.Email) {
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
|
||||
}
|
||||
|
||||
// Default to a user certificate with usernames as principals if those options
|
||||
// are not set.
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
|
||||
|
||||
return append(signOptions,
|
||||
// Set the default extensions
|
||||
|
@ -364,9 +367,9 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{o.claimer},
|
||||
&sshCertValidityValidator{o.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
|
@ -374,14 +377,14 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
|||
func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
claims, err := o.authorizeToken(token)
|
||||
if err != nil {
|
||||
return err
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHRevoke")
|
||||
}
|
||||
|
||||
// Only admins can revoke certificates.
|
||||
if o.IsAdmin(claims.Email) {
|
||||
return nil
|
||||
if !o.IsAdmin(claims.Email) {
|
||||
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
|
||||
}
|
||||
return errors.New("cannot revoke with non-admin token")
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAndDecode(uri string, v interface{}) error {
|
||||
|
|
|
@ -7,12 +7,14 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -206,20 +208,21 @@ func TestOIDC_authorizeToken(t *testing.T) {
|
|||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", p1, args{t1}, false},
|
||||
{"ok2", p2, args{t2}, false},
|
||||
{"fail-email", p3, args{failEmail}, true},
|
||||
{"fail-domain", p3, args{failDomain}, true},
|
||||
{"fail-key", p1, args{failKey}, true},
|
||||
{"fail-token", p1, args{failTok}, true},
|
||||
{"fail-claims", p1, args{failClaims}, true},
|
||||
{"fail-issuer", p1, args{failIss}, true},
|
||||
{"fail-audience", p1, args{failAud}, true},
|
||||
{"fail-signature", p1, args{failSig}, true},
|
||||
{"fail-expired", p1, args{failExp}, true},
|
||||
{"fail-not-before", p1, args{failNbf}, true},
|
||||
{"ok1", p1, args{t1}, http.StatusOK, false},
|
||||
{"ok2", p2, args{t2}, http.StatusOK, false},
|
||||
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
|
||||
{"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, true},
|
||||
{"fail-key", p1, args{failKey}, http.StatusUnauthorized, true},
|
||||
{"fail-token", p1, args{failTok}, http.StatusUnauthorized, true},
|
||||
{"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, true},
|
||||
{"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, true},
|
||||
{"fail-audience", p1, args{failAud}, http.StatusUnauthorized, true},
|
||||
{"fail-signature", p1, args{failSig}, http.StatusUnauthorized, true},
|
||||
{"fail-expired", p1, args{failExp}, http.StatusUnauthorized, true},
|
||||
{"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -230,6 +233,9 @@ func TestOIDC_authorizeToken(t *testing.T) {
|
|||
return
|
||||
}
|
||||
if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NotNil(t, got)
|
||||
|
@ -282,21 +288,24 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
|||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", p1, args{t1}, false},
|
||||
{"admin", p3, args{okAdmin}, false},
|
||||
{"fail-email", p3, args{failEmail}, true},
|
||||
{"ok1", p1, args{t1}, http.StatusOK, false},
|
||||
{"admin", p3, args{okAdmin}, http.StatusOK, false},
|
||||
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
|
||||
got, err := tt.prov.AuthorizeSign(context.Background(), tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
if assert.NotNil(t, got) {
|
||||
|
@ -330,6 +339,107 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
var keys jose.JSONWebKeySet
|
||||
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||
|
||||
// Create test provisioners
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
// Admin + Domains
|
||||
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
|
||||
p3.Domains = []string{"smallstep.com"}
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
config := Config{Claims: globalProvisionerClaims}
|
||||
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
assert.FatalError(t, p1.Init(config))
|
||||
assert.FatalError(t, p3.Init(config))
|
||||
|
||||
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// Admin email not in domains
|
||||
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// Invalid email
|
||||
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", p1, args{t1}, http.StatusUnauthorized, true},
|
||||
{"admin", p3, args{okAdmin}, http.StatusOK, false},
|
||||
{"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Println(tt)
|
||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
|
@ -351,9 +461,16 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
p5, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p6, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
// Admin + Domains
|
||||
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
|
||||
p3.Domains = []string{"smallstep.com"}
|
||||
// disable sshCA
|
||||
disable := false
|
||||
p6.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
config := Config{Claims: globalProvisionerClaims}
|
||||
|
@ -425,48 +542,53 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
|||
prov *OIDC
|
||||
args args
|
||||
expected *SSHOptions
|
||||
code int
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
|
||||
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false},
|
||||
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub},
|
||||
&SSHOptions{CertType: "user", Principals: []string{"name"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
|
||||
{"ok-principals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{Principals: []string{"mariano"}}, pub},
|
||||
&SSHOptions{CertType: "user", Principals: []string{"mariano"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
|
||||
{"ok-emptyPrincipals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{}, pub},
|
||||
&SSHOptions{CertType: "user", Principals: []string{"max", "mariano"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub},
|
||||
&SSHOptions{CertType: "user", Principals: []string{"name"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
|
||||
{"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, false, false},
|
||||
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, false, false},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
|
||||
{"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, http.StatusOK, false, false},
|
||||
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, http.StatusOK, false, false},
|
||||
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub},
|
||||
&SSHOptions{CertType: "user", Principals: []string{"root"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
|
||||
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub},
|
||||
&SSHOptions{CertType: "user", Principals: []string{"name"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false},
|
||||
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
|
||||
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, false, true},
|
||||
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, false, true},
|
||||
{"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, true, false},
|
||||
{"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, true, false},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
|
||||
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub},
|
||||
expectedHostOptions, http.StatusOK, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
|
||||
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, http.StatusOK, false, true},
|
||||
{"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
|
||||
{"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, http.StatusInternalServerError, true, false},
|
||||
{"fail-sshCA-disabled", p6, args{"foo", SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token)
|
||||
got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
|
@ -484,36 +606,32 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||
func TestOIDC_AuthorizeSSHRevoke(t *testing.T) {
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2.Admins = []string{"root@example.com"}
|
||||
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
var keys jose.JSONWebKeySet
|
||||
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||
|
||||
// Create test provisioners
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
// Admin + Domains
|
||||
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
|
||||
p3.Domains = []string{"smallstep.com"}
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
config := Config{Claims: globalProvisionerClaims}
|
||||
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
assert.FatalError(t, p1.Init(config))
|
||||
assert.FatalError(t, p3.Init(config))
|
||||
assert.FatalError(t, p2.Init(config))
|
||||
|
||||
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||
// Invalid email
|
||||
failEmail, err := generateToken("subject", "the-issuer", p1.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// Admin email not in domains
|
||||
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
|
||||
noAdmin, err := generateToken("subject", "the-issuer", p1.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// Invalid email
|
||||
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
|
||||
// Admin email in domains
|
||||
okAdmin, err := generateToken("subject", "the-issuer", p2.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
|
@ -523,52 +641,22 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
|||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", p1, args{t1}, true},
|
||||
{"admin", p3, args{okAdmin}, false},
|
||||
{"fail-email", p3, args{failEmail}, true},
|
||||
{"ok", p2, args{okAdmin}, http.StatusOK, false},
|
||||
{"fail/invalid-token", p1, args{failEmail}, http.StatusUnauthorized, true},
|
||||
{"fail/not-admin", p1, args{noAdmin}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token)
|
||||
err := tt.prov.AuthorizeSSHRevoke(context.Background(), tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Println(tt)
|
||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
|
@ -283,43 +284,43 @@ type base struct{}
|
|||
// AuthorizeSign returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for signing x509 Certificates.
|
||||
func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSign")
|
||||
return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for revoking x509 Certificates.
|
||||
func (b *base) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
return errors.New("not implemented; provisioner does not implement AuthorizeRevoke")
|
||||
return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeRenew returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for renewing x509 Certificates.
|
||||
func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
return errors.New("not implemented; provisioner does not implement AuthorizeRenew")
|
||||
return errs.Unauthorized("provisioner.AuthorizeRenew not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for signing SSH Certificates.
|
||||
func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHSign")
|
||||
return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for revoking SSH Certificates.
|
||||
func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
return errors.New("not implemented; provisioner does not implement AuthorizeSSHRevoke")
|
||||
return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for renewing SSH Certificates.
|
||||
func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
||||
return nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRenew")
|
||||
return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for rekeying SSH Certificates.
|
||||
func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
|
||||
return nil, nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRekey")
|
||||
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
|
||||
}
|
||||
|
||||
// Identity is the type representing an externally supplied identity that is used
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestType_String(t *testing.T) {
|
||||
|
@ -101,3 +105,93 @@ func TestDefaultIdentityFunc(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnimplementedMethods(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
p Interface
|
||||
method Method
|
||||
}{
|
||||
{"jwk/sshRekey", &JWK{}, SSHRekeyMethod},
|
||||
{"jwk/sshRenew", &JWK{}, SSHRenewMethod},
|
||||
{"aws/revoke", &AWS{}, RevokeMethod},
|
||||
{"aws/sshRenew", &AWS{}, SSHRenewMethod},
|
||||
{"aws/rekey", &AWS{}, SSHRekeyMethod},
|
||||
{"aws/sshRevoke", &AWS{}, SSHRevokeMethod},
|
||||
{"azure/revoke", &Azure{}, RevokeMethod},
|
||||
{"azure/sshRenew", &Azure{}, SSHRenewMethod},
|
||||
{"azure/sshRekey", &Azure{}, SSHRekeyMethod},
|
||||
{"azure/sshRevoke", &Azure{}, SSHRevokeMethod},
|
||||
{"gcp/revoke", &GCP{}, RevokeMethod},
|
||||
{"gcp/sshRenew", &GCP{}, SSHRenewMethod},
|
||||
{"gcp/sshRekey", &GCP{}, SSHRekeyMethod},
|
||||
{"gcp/sshRevoke", &GCP{}, SSHRevokeMethod},
|
||||
{"oidc/sshRenew", &OIDC{}, SSHRenewMethod},
|
||||
{"oidc/sshRekey", &OIDC{}, SSHRekeyMethod},
|
||||
{"x5c/sshRenew", &X5C{}, SSHRenewMethod},
|
||||
{"x5c/sshRekey", &X5C{}, SSHRekeyMethod},
|
||||
{"x5c/sshRevoke", &X5C{}, SSHRekeyMethod},
|
||||
{"acme/revoke", &ACME{}, RevokeMethod},
|
||||
{"acme/sshSign", &ACME{}, SSHSignMethod},
|
||||
{"acme/sshRekey", &ACME{}, SSHRekeyMethod},
|
||||
{"acme/sshRenew", &ACME{}, SSHRenewMethod},
|
||||
{"acme/sshRevoke", &ACME{}, SSHRevokeMethod},
|
||||
{"sshpop/sign", &SSHPOP{}, SignMethod},
|
||||
{"sshpop/renew", &SSHPOP{}, RenewMethod},
|
||||
{"sshpop/revoke", &SSHPOP{}, RevokeMethod},
|
||||
{"sshpop/sshSign", &SSHPOP{}, SSHSignMethod},
|
||||
{"k8ssa/sshRekey", &K8sSA{}, SSHRekeyMethod},
|
||||
{"k8ssa/sshRenew", &K8sSA{}, SSHRenewMethod},
|
||||
{"k8ssa/sshRevoke", &K8sSA{}, SSHRevokeMethod},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var (
|
||||
err error
|
||||
msg string
|
||||
)
|
||||
|
||||
switch tt.method {
|
||||
case SignMethod:
|
||||
var signOpts []SignOption
|
||||
signOpts, err = tt.p.AuthorizeSign(context.Background(), "")
|
||||
assert.Nil(t, signOpts)
|
||||
msg = "provisioner.AuthorizeSign not implemented"
|
||||
case RenewMethod:
|
||||
err = tt.p.AuthorizeRenew(context.Background(), nil)
|
||||
msg = "provisioner.AuthorizeRenew not implemented"
|
||||
case RevokeMethod:
|
||||
err = tt.p.AuthorizeRevoke(context.Background(), "")
|
||||
msg = "provisioner.AuthorizeRevoke not implemented"
|
||||
case SSHSignMethod:
|
||||
var signOpts []SignOption
|
||||
signOpts, err = tt.p.AuthorizeSSHSign(context.Background(), "")
|
||||
assert.Nil(t, signOpts)
|
||||
msg = "provisioner.AuthorizeSSHSign not implemented"
|
||||
case SSHRenewMethod:
|
||||
var cert *ssh.Certificate
|
||||
cert, err = tt.p.AuthorizeSSHRenew(context.Background(), "")
|
||||
assert.Nil(t, cert)
|
||||
msg = "provisioner.AuthorizeSSHRenew not implemented"
|
||||
case SSHRekeyMethod:
|
||||
var (
|
||||
cert *ssh.Certificate
|
||||
signOpts []SignOption
|
||||
)
|
||||
cert, signOpts, err = tt.p.AuthorizeSSHRekey(context.Background(), "")
|
||||
assert.Nil(t, cert)
|
||||
assert.Nil(t, signOpts)
|
||||
msg = "provisioner.AuthorizeSSHRekey not implemented"
|
||||
case SSHRevokeMethod:
|
||||
err = tt.p.AuthorizeSSHRevoke(context.Background(), "")
|
||||
msg = "provisioner.AuthorizeSSHRevoke not implemented"
|
||||
default:
|
||||
t.Errorf("unexpected method %s", tt.method)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
|
||||
assert.Equals(t, err.Error(), msg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ type SignOption interface{}
|
|||
// CertificateValidator is the interface used to validate a X.509 certificate.
|
||||
type CertificateValidator interface {
|
||||
SignOption
|
||||
Valid(crt *x509.Certificate) error
|
||||
Valid(cert *x509.Certificate, o Options) error
|
||||
}
|
||||
|
||||
// CertificateRequestValidator is the interface used to validate a X.509
|
||||
|
@ -106,7 +106,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
|
|||
return errors.New("certificate request cannot contain an empty common name")
|
||||
}
|
||||
if req.Subject.CommonName != string(v) {
|
||||
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
|
||||
return errors.Errorf("certificate request does not contain the valid common name; requested common name = %s, token subject = %s", req.Subject.CommonName, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -265,35 +265,32 @@ func newValidityValidator(min, max time.Duration) *validityValidator {
|
|||
|
||||
// Valid validates the certificate validity settings (notBefore/notAfter) and
|
||||
// and total duration.
|
||||
func (v *validityValidator) Valid(crt *x509.Certificate) error {
|
||||
func (v *validityValidator) Valid(cert *x509.Certificate, o Options) error {
|
||||
var (
|
||||
na = crt.NotAfter.Truncate(time.Second)
|
||||
nb = crt.NotBefore.Truncate(time.Second)
|
||||
na = cert.NotAfter.Truncate(time.Second)
|
||||
nb = cert.NotBefore.Truncate(time.Second)
|
||||
now = time.Now().Truncate(time.Second)
|
||||
)
|
||||
|
||||
// To not take into account the backdate, time.Now() will be used to
|
||||
// calculate the duration if NotBefore is in the past.
|
||||
var d time.Duration
|
||||
if now.After(nb) {
|
||||
d = na.Sub(now)
|
||||
} else {
|
||||
d = na.Sub(nb)
|
||||
}
|
||||
d := na.Sub(nb)
|
||||
|
||||
if na.Before(now) {
|
||||
return errors.Errorf("NotAfter: %v cannot be in the past", na)
|
||||
return errors.Errorf("notAfter cannot be in the past; na=%v", na)
|
||||
}
|
||||
if na.Before(nb) {
|
||||
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
|
||||
return errors.Errorf("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
|
||||
}
|
||||
if d < v.min {
|
||||
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
|
||||
d, v.min)
|
||||
}
|
||||
if d > v.max {
|
||||
// NOTE: this check is not "technically correct". We're allowing the max
|
||||
// duration of a cert to be "max + backdate" and not all certificates will
|
||||
// be backdated (e.g. if a user passes the NotBefore value then we do not
|
||||
// apply a backdate). This is good enough.
|
||||
if d > v.max+o.Backdate {
|
||||
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
|
||||
d, v.max)
|
||||
d, v.max+o.Backdate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,9 +3,10 @@ package provisioner
|
|||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -48,22 +49,22 @@ func Test_emailOnlyIdentity_Valid(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
|
||||
_shortRSA, err := pemutil.Read("./testdata/short-rsa.csr")
|
||||
_shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr")
|
||||
assert.FatalError(t, err)
|
||||
shortRSA, ok := _shortRSA.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
|
||||
_rsa, err := pemutil.Read("./testdata/rsa.csr")
|
||||
_rsa, err := pemutil.Read("./testdata/certs/rsa.csr")
|
||||
assert.FatalError(t, err)
|
||||
rsaCSR, ok := _rsa.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
|
||||
_ecdsa, err := pemutil.Read("./testdata/ecdsa.csr")
|
||||
_ecdsa, err := pemutil.Read("./testdata/certs/ecdsa.csr")
|
||||
assert.FatalError(t, err)
|
||||
ecdsaCSR, ok := _ecdsa.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
|
||||
_ed25519, err := pemutil.Read("./testdata/ed25519.csr")
|
||||
_ed25519, err := pemutil.Read("./testdata/certs/ed25519.csr")
|
||||
assert.FatalError(t, err)
|
||||
ed25519CSR, ok := _ed25519.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
|
@ -246,30 +247,191 @@ func Test_ipAddressesValidator_Valid(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_validityValidator_Valid(t *testing.T) {
|
||||
type fields struct {
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
type test struct {
|
||||
cert *x509.Certificate
|
||||
opts Options
|
||||
vv *validityValidator
|
||||
err error
|
||||
}
|
||||
type args struct {
|
||||
crt *x509.Certificate
|
||||
tests := map[string]func() test{
|
||||
"fail/notAfter-past": func() test {
|
||||
return test{
|
||||
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
|
||||
cert: &x509.Certificate{NotAfter: time.Now().Add(-5 * time.Minute)},
|
||||
opts: Options{},
|
||||
err: errors.New("notAfter cannot be in the past"),
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
},
|
||||
"fail/notBefore-after-notAfter": func() test {
|
||||
return test{
|
||||
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
|
||||
cert: &x509.Certificate{NotBefore: time.Now().Add(10 * time.Minute),
|
||||
NotAfter: time.Now().Add(5 * time.Minute)},
|
||||
opts: Options{},
|
||||
err: errors.New("notAfter cannot be before notBefore"),
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := &validityValidator{
|
||||
min: tt.fields.min,
|
||||
max: tt.fields.max,
|
||||
},
|
||||
"fail/duration-too-short": func() test {
|
||||
n := now()
|
||||
return test{
|
||||
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
|
||||
cert: &x509.Certificate{NotBefore: n,
|
||||
NotAfter: n.Add(3 * time.Minute)},
|
||||
opts: Options{},
|
||||
err: errors.New("is less than the authorized minimum certificate duration of "),
|
||||
}
|
||||
if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr {
|
||||
t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
},
|
||||
"ok/duration-exactly-min": func() test {
|
||||
n := now()
|
||||
return test{
|
||||
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
|
||||
cert: &x509.Certificate{NotBefore: n,
|
||||
NotAfter: n.Add(5 * time.Minute)},
|
||||
opts: Options{},
|
||||
}
|
||||
},
|
||||
"fail/duration-too-great": func() test {
|
||||
n := now()
|
||||
return test{
|
||||
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
|
||||
cert: &x509.Certificate{NotBefore: n,
|
||||
NotAfter: n.Add(24*time.Hour + time.Second)},
|
||||
err: errors.New("is more than the authorized maximum certificate duration of "),
|
||||
}
|
||||
},
|
||||
"ok/duration-exactly-max": func() test {
|
||||
n := time.Now()
|
||||
return test{
|
||||
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
|
||||
cert: &x509.Certificate{NotBefore: n,
|
||||
NotAfter: n.Add(24 * time.Hour)},
|
||||
}
|
||||
},
|
||||
"ok/duration-exact-min-with-backdate": func() test {
|
||||
now := time.Now()
|
||||
cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(5 * time.Minute)}
|
||||
time.Sleep(time.Second)
|
||||
return test{
|
||||
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
|
||||
cert: cert,
|
||||
opts: Options{Backdate: time.Second},
|
||||
}
|
||||
},
|
||||
"ok/duration-exact-max-with-backdate": func() test {
|
||||
backdate := time.Second
|
||||
now := time.Now()
|
||||
cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(24*time.Hour + backdate)}
|
||||
time.Sleep(backdate)
|
||||
return test{
|
||||
vv: &validityValidator{5 * time.Minute, 24 * time.Hour},
|
||||
cert: cert,
|
||||
opts: Options{Backdate: backdate},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tt := run()
|
||||
if err := tt.vv.Valid(tt.cert, tt.opts); err != nil {
|
||||
if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) {
|
||||
assert.True(t, strings.Contains(err.Error(), tt.err.Error()),
|
||||
fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error()))
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tt.err, fmt.Sprintf("expected err = %s, but not <nil>", tt.err))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_profileDefaultDuration_Option(t *testing.T) {
|
||||
type test struct {
|
||||
so Options
|
||||
pdd profileDefaultDuration
|
||||
cert *x509.Certificate
|
||||
valid func(*x509.Certificate)
|
||||
}
|
||||
tests := map[string]func() test{
|
||||
"ok/notBefore-notAfter-duration-empty": func() test {
|
||||
return test{
|
||||
pdd: profileDefaultDuration(0),
|
||||
so: Options{},
|
||||
cert: new(x509.Certificate),
|
||||
valid: func(cert *x509.Certificate) {
|
||||
n := now()
|
||||
assert.True(t, n.After(cert.NotBefore))
|
||||
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
|
||||
|
||||
assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter))
|
||||
assert.True(t, n.Add(24*time.Hour).Add(-1*time.Minute).Before(cert.NotAfter))
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/notBefore-set": func() test {
|
||||
nb := time.Now().Add(5 * time.Minute).UTC()
|
||||
return test{
|
||||
pdd: profileDefaultDuration(0),
|
||||
so: Options{NotBefore: NewTimeDuration(nb)},
|
||||
cert: new(x509.Certificate),
|
||||
valid: func(cert *x509.Certificate) {
|
||||
assert.Equals(t, cert.NotBefore, nb)
|
||||
assert.Equals(t, cert.NotAfter, nb.Add(24*time.Hour))
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/duration-set": func() test {
|
||||
d := 4 * time.Hour
|
||||
return test{
|
||||
pdd: profileDefaultDuration(d),
|
||||
so: Options{Backdate: time.Second},
|
||||
cert: new(x509.Certificate),
|
||||
valid: func(cert *x509.Certificate) {
|
||||
n := now()
|
||||
assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore))
|
||||
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
|
||||
|
||||
assert.True(t, n.Add(d).After(cert.NotAfter))
|
||||
assert.True(t, n.Add(d).Add(-1*time.Minute).Before(cert.NotAfter))
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/notAfter-set": func() test {
|
||||
na := now().Add(10 * time.Minute).UTC()
|
||||
return test{
|
||||
pdd: profileDefaultDuration(0),
|
||||
so: Options{NotAfter: NewTimeDuration(na)},
|
||||
cert: new(x509.Certificate),
|
||||
valid: func(cert *x509.Certificate) {
|
||||
n := now()
|
||||
assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore))
|
||||
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
|
||||
|
||||
assert.Equals(t, cert.NotAfter, na)
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/notBefore-and-notAfter-set": func() test {
|
||||
nb := time.Now().Add(5 * time.Minute).UTC()
|
||||
na := time.Now().Add(10 * time.Minute).UTC()
|
||||
d := 4 * time.Hour
|
||||
return test{
|
||||
pdd: profileDefaultDuration(d),
|
||||
so: Options{NotBefore: NewTimeDuration(nb), NotAfter: NewTimeDuration(na)},
|
||||
cert: new(x509.Certificate),
|
||||
valid: func(cert *x509.Certificate) {
|
||||
assert.Equals(t, cert.NotBefore, nb)
|
||||
assert.Equals(t, cert.NotAfter, na)
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tt := run()
|
||||
prof := &x509util.Leaf{}
|
||||
prof.SetSubject(tt.cert)
|
||||
assert.FatalError(t, tt.pdd.Option(tt.so)(prof), "unexpected error")
|
||||
tt.valid(prof.Subject())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -381,43 +543,3 @@ func Test_profileLimitDuration_Option(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_profileDefaultDuration_Option(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
|
||||
v := profileDefaultDuration(24 * time.Hour)
|
||||
type args struct {
|
||||
so Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
v profileDefaultDuration
|
||||
args args
|
||||
want *x509.Certificate
|
||||
}{
|
||||
{"default", v, args{Options{}}, &x509.Certificate{NotBefore: tm, NotAfter: tm.Add(24 * time.Hour)}},
|
||||
{"backdate", v, args{Options{Backdate: 1 * time.Minute}}, &x509.Certificate{NotBefore: tm.Add(-1 * time.Minute), NotAfter: tm.Add(24 * time.Hour)}},
|
||||
{"notBefore", v, args{Options{NotBefore: NewTimeDuration(tm.Add(10 * time.Second))}}, &x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(24*time.Hour + 10*time.Second)}},
|
||||
{"notAfter", v, args{Options{NotAfter: NewTimeDuration(tm.Add(1 * time.Hour))}}, &x509.Certificate{NotBefore: tm, NotAfter: tm.Add(1 * time.Hour)}},
|
||||
{"notBefore and notAfter", v, args{Options{NotBefore: NewTimeDuration(tm.Add(10 * time.Second)), NotAfter: NewTimeDuration(tm.Add(1 * time.Hour))}},
|
||||
&x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(1 * time.Hour)}},
|
||||
{"notBefore and backdate", v, args{Options{Backdate: 1 * time.Minute, NotBefore: NewTimeDuration(tm.Add(10 * time.Second))}},
|
||||
&x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(24*time.Hour + 10*time.Second)}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cert := &x509.Certificate{}
|
||||
profile := &x509util.Leaf{}
|
||||
profile.SetSubject(cert)
|
||||
|
||||
fn := tt.v.Option(tt.args.so)
|
||||
if err := fn(profile); err != nil {
|
||||
t.Errorf("profileDefaultDuration.Option() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(cert, tt.want) {
|
||||
t.Errorf("profileDefaultDuration.Option() = %v, \nwant %v", cert, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,29 +19,29 @@ const (
|
|||
SSHHostCert = "host"
|
||||
)
|
||||
|
||||
// SSHCertificateModifier is the interface used to change properties in an SSH
|
||||
// SSHCertModifier is the interface used to change properties in an SSH
|
||||
// certificate.
|
||||
type SSHCertificateModifier interface {
|
||||
type SSHCertModifier interface {
|
||||
SignOption
|
||||
Modify(cert *ssh.Certificate) error
|
||||
}
|
||||
|
||||
// SSHCertificateOptionModifier is the interface used to add custom options used
|
||||
// SSHCertOptionModifier is the interface used to add custom options used
|
||||
// to modify the SSH certificate.
|
||||
type SSHCertificateOptionModifier interface {
|
||||
type SSHCertOptionModifier interface {
|
||||
SignOption
|
||||
Option(o SSHOptions) SSHCertificateModifier
|
||||
Option(o SSHOptions) SSHCertModifier
|
||||
}
|
||||
|
||||
// SSHCertificateValidator is the interface used to validate an SSH certificate.
|
||||
type SSHCertificateValidator interface {
|
||||
// SSHCertValidator is the interface used to validate an SSH certificate.
|
||||
type SSHCertValidator interface {
|
||||
SignOption
|
||||
Valid(cert *ssh.Certificate) error
|
||||
Valid(cert *ssh.Certificate, opts SSHOptions) error
|
||||
}
|
||||
|
||||
// SSHCertificateOptionsValidator is the interface used to validate the custom
|
||||
// SSHCertOptionsValidator is the interface used to validate the custom
|
||||
// options used to modify the SSH certificate.
|
||||
type SSHCertificateOptionsValidator interface {
|
||||
type SSHCertOptionsValidator interface {
|
||||
SignOption
|
||||
Valid(got SSHOptions) error
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ func (o SSHOptions) Type() uint32 {
|
|||
return sshCertTypeUInt32(o.CertType)
|
||||
}
|
||||
|
||||
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
|
||||
// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
|
||||
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
|
||||
switch o.CertType {
|
||||
case "": // ignore
|
||||
|
@ -78,7 +78,7 @@ func (o SSHOptions) Modify(cert *ssh.Certificate) error {
|
|||
case SSHHostCert:
|
||||
cert.CertType = ssh.HostCert
|
||||
default:
|
||||
return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType)
|
||||
return errors.Errorf("ssh certificate has an unknown type - %s", o.CertType)
|
||||
}
|
||||
|
||||
cert.KeyId = o.KeyID
|
||||
|
@ -116,7 +116,7 @@ func (o SSHOptions) match(got SSHOptions) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the
|
||||
// sshCertPrincipalsModifier is an SSHCertModifier that sets the
|
||||
// principals to the SSH certificate.
|
||||
type sshCertPrincipalsModifier []string
|
||||
|
||||
|
@ -126,16 +126,16 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given
|
||||
// sshCertKeyIDModifier is an SSHCertModifier that sets the given
|
||||
// Key ID in the SSH certificate.
|
||||
type sshCertificateKeyIDModifier string
|
||||
type sshCertKeyIDModifier string
|
||||
|
||||
func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
||||
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
||||
cert.KeyId = string(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertTypeModifier is an SSHCertificateModifier that sets the
|
||||
// sshCertTypeModifier is an SSHCertModifier that sets the
|
||||
// certificate type.
|
||||
type sshCertTypeModifier string
|
||||
|
||||
|
@ -145,30 +145,30 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the
|
||||
// sshCertValidAfterModifier is an SSHCertModifier that sets the
|
||||
// ValidAfter in the SSH certificate.
|
||||
type sshCertificateValidAfterModifier uint64
|
||||
type sshCertValidAfterModifier uint64
|
||||
|
||||
func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
||||
func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
||||
cert.ValidAfter = uint64(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the
|
||||
// sshCertValidBeforeModifier is an SSHCertModifier that sets the
|
||||
// ValidBefore in the SSH certificate.
|
||||
type sshCertificateValidBeforeModifier uint64
|
||||
type sshCertValidBeforeModifier uint64
|
||||
|
||||
func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error {
|
||||
func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error {
|
||||
cert.ValidBefore = uint64(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateDefaultModifier implements a SSHCertificateModifier that
|
||||
// sshCertDefaultsModifier implements a SSHCertModifier that
|
||||
// modifies the certificate with the given options if they are not set.
|
||||
type sshCertificateDefaultsModifier SSHOptions
|
||||
type sshCertDefaultsModifier SSHOptions
|
||||
|
||||
// Modify implements the SSHCertificateModifier interface.
|
||||
func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
|
||||
// Modify implements the SSHCertModifier interface.
|
||||
func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
|
||||
if cert.CertType == 0 {
|
||||
cert.CertType = sshCertTypeUInt32(m.CertType)
|
||||
}
|
||||
|
@ -184,7 +184,7 @@ func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
|
||||
// sshDefaultExtensionModifier implements an SSHCertModifier that sets
|
||||
// the default extensions in an SSH certificate.
|
||||
type sshDefaultExtensionModifier struct{}
|
||||
|
||||
|
@ -208,14 +208,14 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
|
|||
}
|
||||
}
|
||||
|
||||
// sshDefaultDuration is an SSHCertificateModifier that sets the certificate
|
||||
// sshDefaultDuration is an SSHCertModifier that sets the certificate
|
||||
// ValidAfter and ValidBefore if they have not been set. It will fail if a
|
||||
// CertType has not been set or is not valid.
|
||||
type sshDefaultDuration struct {
|
||||
*Claimer
|
||||
}
|
||||
|
||||
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertificateModifier {
|
||||
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertModifier {
|
||||
return sshModifierFunc(func(cert *ssh.Certificate) error {
|
||||
d, err := m.DefaultSSHCertDuration(cert.CertType)
|
||||
if err != nil {
|
||||
|
@ -248,7 +248,7 @@ type sshLimitDuration struct {
|
|||
NotAfter time.Time
|
||||
}
|
||||
|
||||
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
|
||||
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertModifier {
|
||||
if m.NotAfter.IsZero() {
|
||||
defaultDuration := &sshDefaultDuration{m.Claimer}
|
||||
return defaultDuration.Option(o)
|
||||
|
@ -295,22 +295,22 @@ func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
|
|||
})
|
||||
}
|
||||
|
||||
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
|
||||
// sshCertOptionsValidator validates the user SSHOptions with the ones
|
||||
// usually present in the token.
|
||||
type sshCertificateOptionsValidator SSHOptions
|
||||
type sshCertOptionsValidator SSHOptions
|
||||
|
||||
// Valid implements SSHCertificateOptionsValidator and returns nil if both
|
||||
// Valid implements SSHCertOptionsValidator and returns nil if both
|
||||
// SSHOptions match.
|
||||
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error {
|
||||
func (v sshCertOptionsValidator) Valid(got SSHOptions) error {
|
||||
want := SSHOptions(v)
|
||||
return want.match(got)
|
||||
}
|
||||
|
||||
type sshCertificateValidityValidator struct {
|
||||
type sshCertValidityValidator struct {
|
||||
*Claimer
|
||||
}
|
||||
|
||||
func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
|
||||
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SSHOptions) error {
|
||||
switch {
|
||||
case cert.ValidAfter == 0:
|
||||
return errors.New("ssh certificate validAfter cannot be 0")
|
||||
|
@ -336,31 +336,26 @@ func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
|
|||
|
||||
// To not take into account the backdate, time.Now() will be used to
|
||||
// calculate the duration if ValidAfter is in the past.
|
||||
var dur time.Duration
|
||||
if t := now().Unix(); t > int64(cert.ValidAfter) {
|
||||
dur = time.Duration(int64(cert.ValidBefore)-t) * time.Second
|
||||
} else {
|
||||
dur = time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
|
||||
}
|
||||
dur := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
|
||||
|
||||
switch {
|
||||
case dur < min:
|
||||
return errors.Errorf("requested duration of %s is less than minimum "+
|
||||
"accepted duration for selected provisioner of %s", dur, min)
|
||||
case dur > max:
|
||||
case dur > max+opts.Backdate:
|
||||
return errors.Errorf("requested duration of %s is greater than maximum "+
|
||||
"accepted duration for selected provisioner of %s", dur, max)
|
||||
"accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// sshCertificateDefaultValidator implements a simple validator for all the
|
||||
// sshCertDefaultValidator implements a simple validator for all the
|
||||
// fields in the SSH certificate.
|
||||
type sshCertificateDefaultValidator struct{}
|
||||
type sshCertDefaultValidator struct{}
|
||||
|
||||
// Valid returns an error if the given certificate does not contain the necessary fields.
|
||||
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
|
||||
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
|
||||
switch {
|
||||
case len(cert.Nonce) == 0:
|
||||
return errors.New("ssh certificate nonce cannot be empty")
|
||||
|
@ -395,7 +390,7 @@ func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
|
|||
type sshDefaultPublicKeyValidator struct{}
|
||||
|
||||
// Valid checks that certificate request common name matches the one configured.
|
||||
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
|
||||
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
|
||||
if cert.Key == nil {
|
||||
return errors.New("ssh certificate key cannot be nil")
|
||||
}
|
||||
|
@ -425,7 +420,7 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
|
|||
type sshCertKeyIDValidator string
|
||||
|
||||
// Valid returns an error if the given certificate does not contain the necessary fields.
|
||||
func (v sshCertKeyIDValidator) Valid(cert *ssh.Certificate) error {
|
||||
func (v sshCertKeyIDValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
|
||||
if string(v) != cert.KeyId {
|
||||
return errors.Errorf("invalid ssh certificate KeyId; want %s, but got %s", string(v), cert.KeyId)
|
||||
}
|
||||
|
|
|
@ -38,12 +38,463 @@ func TestSSHOptions_Type(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
|
||||
func TestSSHOptions_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
so *SSHOptions
|
||||
cert *ssh.Certificate
|
||||
valid func(*ssh.Certificate)
|
||||
err error
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
"fail/unexpected-cert-type": func() test {
|
||||
return test{
|
||||
so: &SSHOptions{CertType: "foo"},
|
||||
cert: new(ssh.Certificate),
|
||||
err: errors.Errorf("ssh certificate has an unknown type - foo"),
|
||||
}
|
||||
},
|
||||
"fail/validAfter-greater-validBefore": func() test {
|
||||
return test{
|
||||
so: &SSHOptions{CertType: "user"},
|
||||
cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)},
|
||||
err: errors.Errorf("ssh certificate valid after cannot be greater than valid before"),
|
||||
}
|
||||
},
|
||||
"ok/user-cert": func() test {
|
||||
return test{
|
||||
so: &SSHOptions{CertType: "user"},
|
||||
cert: new(ssh.Certificate),
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/host-cert": func() test {
|
||||
return test{
|
||||
so: &SSHOptions{CertType: "host"},
|
||||
cert: new(ssh.Certificate),
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok": func() test {
|
||||
va := time.Now().Add(5 * time.Minute)
|
||||
vb := time.Now().Add(1 * time.Hour)
|
||||
so := &SSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"},
|
||||
ValidAfter: NewTimeDuration(va), ValidBefore: NewTimeDuration(vb)}
|
||||
return test{
|
||||
so: so,
|
||||
cert: new(ssh.Certificate),
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
|
||||
assert.Equals(t, cert.KeyId, so.KeyID)
|
||||
assert.Equals(t, cert.ValidPrincipals, so.Principals)
|
||||
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
|
||||
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if err := tc.so.Modify(tc.cert); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
tc.valid(tc.cert)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHOptions_Match(t *testing.T) {
|
||||
type test struct {
|
||||
so SSHOptions
|
||||
cmp SSHOptions
|
||||
err error
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
"fail/cert-type": func() test {
|
||||
return test{
|
||||
so: SSHOptions{CertType: "foo"},
|
||||
cmp: SSHOptions{CertType: "bar"},
|
||||
err: errors.Errorf("ssh certificate type does not match - got bar, want foo"),
|
||||
}
|
||||
},
|
||||
"fail/pricipals": func() test {
|
||||
return test{
|
||||
so: SSHOptions{Principals: []string{"foo"}},
|
||||
cmp: SSHOptions{Principals: []string{"bar"}},
|
||||
err: errors.Errorf("ssh certificate principals does not match - got [bar], want [foo]"),
|
||||
}
|
||||
},
|
||||
"fail/validAfter": func() test {
|
||||
return test{
|
||||
so: SSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))},
|
||||
cmp: SSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))},
|
||||
err: errors.Errorf("ssh certificate valid after does not match"),
|
||||
}
|
||||
},
|
||||
"fail/validBefore": func() test {
|
||||
return test{
|
||||
so: SSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))},
|
||||
cmp: SSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))},
|
||||
err: errors.Errorf("ssh certificate valid before does not match"),
|
||||
}
|
||||
},
|
||||
"ok/original-empty": func() test {
|
||||
return test{
|
||||
so: SSHOptions{},
|
||||
cmp: SSHOptions{
|
||||
CertType: "foo",
|
||||
Principals: []string{"foo"},
|
||||
ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)),
|
||||
ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)),
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/cmp-empty": func() test {
|
||||
return test{
|
||||
cmp: SSHOptions{},
|
||||
so: SSHOptions{
|
||||
CertType: "foo",
|
||||
Principals: []string{"foo"},
|
||||
ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)),
|
||||
ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)),
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/equal": func() test {
|
||||
n := time.Now()
|
||||
va := NewTimeDuration(n.Add(1 * time.Minute))
|
||||
vb := NewTimeDuration(n.Add(5 * time.Minute))
|
||||
return test{
|
||||
cmp: SSHOptions{
|
||||
CertType: "foo",
|
||||
Principals: []string{"foo"},
|
||||
ValidAfter: va,
|
||||
ValidBefore: vb,
|
||||
},
|
||||
so: SSHOptions{
|
||||
CertType: "foo",
|
||||
Principals: []string{"foo"},
|
||||
ValidAfter: va,
|
||||
ValidBefore: vb,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if err := tc.so.match(tc.cmp); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertPrincipalsModifier
|
||||
cert *ssh.Certificate
|
||||
expected []string
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
"ok": func() test {
|
||||
a := []string{"foo", "bar"}
|
||||
return test{
|
||||
modifier: sshCertPrincipalsModifier(a),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: a,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
||||
assert.Equals(t, tc.cert.ValidPrincipals, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertKeyIDModifier
|
||||
cert *ssh.Certificate
|
||||
expected string
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
"ok": func() test {
|
||||
a := "foo"
|
||||
return test{
|
||||
modifier: sshCertKeyIDModifier(a),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: a,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
||||
assert.Equals(t, tc.cert.KeyId, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertTypeModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertTypeModifier
|
||||
cert *ssh.Certificate
|
||||
expected uint32
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
"ok/user": func() test {
|
||||
return test{
|
||||
modifier: sshCertTypeModifier("user"),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: ssh.UserCert,
|
||||
}
|
||||
},
|
||||
"ok/host": func() test {
|
||||
return test{
|
||||
modifier: sshCertTypeModifier("host"),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: ssh.HostCert,
|
||||
}
|
||||
},
|
||||
"ok/default": func() test {
|
||||
return test{
|
||||
modifier: sshCertTypeModifier("foo"),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: 0,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
||||
assert.Equals(t, tc.cert.CertType, uint32(tc.expected))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertValidAfterModifier
|
||||
cert *ssh.Certificate
|
||||
expected uint64
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
"ok": func() test {
|
||||
return test{
|
||||
modifier: sshCertValidAfterModifier(15),
|
||||
cert: new(ssh.Certificate),
|
||||
expected: 15,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
||||
assert.Equals(t, tc.cert.ValidAfter, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshCertDefaultsModifier
|
||||
cert *ssh.Certificate
|
||||
valid func(*ssh.Certificate)
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
"ok/changes": func() test {
|
||||
n := time.Now()
|
||||
va := NewTimeDuration(n.Add(1 * time.Minute))
|
||||
vb := NewTimeDuration(n.Add(5 * time.Minute))
|
||||
so := SSHOptions{
|
||||
Principals: []string{"foo", "bar"},
|
||||
CertType: "host",
|
||||
ValidAfter: va,
|
||||
ValidBefore: vb,
|
||||
}
|
||||
return test{
|
||||
modifier: sshCertDefaultsModifier(so),
|
||||
cert: new(ssh.Certificate),
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Equals(t, cert.ValidPrincipals, so.Principals)
|
||||
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
|
||||
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
|
||||
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/no-changes": func() test {
|
||||
n := time.Now()
|
||||
so := SSHOptions{
|
||||
Principals: []string{"foo", "bar"},
|
||||
CertType: "host",
|
||||
ValidAfter: NewTimeDuration(n.Add(15 * time.Minute)),
|
||||
ValidBefore: NewTimeDuration(n.Add(25 * time.Minute)),
|
||||
}
|
||||
return test{
|
||||
modifier: sshCertDefaultsModifier(so),
|
||||
cert: &ssh.Certificate{
|
||||
CertType: uint32(ssh.UserCert),
|
||||
ValidPrincipals: []string{"zap", "zoop"},
|
||||
ValidAfter: 15,
|
||||
ValidBefore: 25,
|
||||
},
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Equals(t, cert.ValidPrincipals, []string{"zap", "zoop"})
|
||||
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
|
||||
assert.Equals(t, cert.ValidAfter, uint64(15))
|
||||
assert.Equals(t, cert.ValidBefore, uint64(25))
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
||||
tc.valid(tc.cert)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
||||
type test struct {
|
||||
modifier sshDefaultExtensionModifier
|
||||
cert *ssh.Certificate
|
||||
valid func(*ssh.Certificate)
|
||||
err error
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
"fail/unexpected-cert-type": func() test {
|
||||
cert := &ssh.Certificate{CertType: 3}
|
||||
return test{
|
||||
modifier: sshDefaultExtensionModifier{},
|
||||
cert: cert,
|
||||
err: errors.New("ssh certificate type has not been set or is invalid"),
|
||||
}
|
||||
},
|
||||
"ok/host": func() test {
|
||||
cert := &ssh.Certificate{CertType: ssh.HostCert}
|
||||
return test{
|
||||
modifier: sshDefaultExtensionModifier{},
|
||||
cert: cert,
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
assert.Len(t, 0, cert.Extensions)
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/user/extensions-exists": func() test {
|
||||
cert := &ssh.Certificate{CertType: ssh.UserCert, Permissions: ssh.Permissions{Extensions: map[string]string{
|
||||
"foo": "bar",
|
||||
}}}
|
||||
return test{
|
||||
modifier: sshDefaultExtensionModifier{},
|
||||
cert: cert,
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
val, ok := cert.Extensions["foo"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "bar")
|
||||
|
||||
val, ok = cert.Extensions["permit-X11-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-agent-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-port-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-pty"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-user-rc"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/user/no-extensions": func() test {
|
||||
return test{
|
||||
modifier: sshDefaultExtensionModifier{},
|
||||
cert: &ssh.Certificate{CertType: ssh.UserCert},
|
||||
valid: func(cert *ssh.Certificate) {
|
||||
_, ok := cert.Extensions["foo"]
|
||||
assert.False(t, ok)
|
||||
|
||||
val, ok := cert.Extensions["permit-X11-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-agent-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-port-forwarding"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-pty"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
|
||||
val, ok = cert.Extensions["permit-user-rc"]
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, val, "")
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run()
|
||||
if err := tc.modifier.Modify(tc.cert); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
tc.valid(tc.cert)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
|
||||
pub, _, err := keys.GenerateDefaultKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
sshPub, err := ssh.NewPublicKey(pub)
|
||||
assert.FatalError(t, err)
|
||||
v := sshCertificateDefaultValidator{}
|
||||
v := sshCertDefaultValidator{}
|
||||
tests := []struct {
|
||||
name string
|
||||
cert *ssh.Certificate
|
||||
|
@ -208,7 +659,7 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := v.Valid(tt.cert); err != nil {
|
||||
if err := v.Valid(tt.cert, SSHOptions{}); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
|
@ -219,34 +670,39 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_sshCertificateValidityValidator(t *testing.T) {
|
||||
func Test_sshCertValidityValidator(t *testing.T) {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
v := sshCertificateValidityValidator{p.claimer}
|
||||
v := sshCertValidityValidator{p.claimer}
|
||||
n := now()
|
||||
tests := []struct {
|
||||
name string
|
||||
cert *ssh.Certificate
|
||||
opts SSHOptions
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"fail/validAfter-0",
|
||||
&ssh.Certificate{CertType: ssh.UserCert},
|
||||
SSHOptions{},
|
||||
errors.New("ssh certificate validAfter cannot be 0"),
|
||||
},
|
||||
{
|
||||
"fail/validBefore-in-past",
|
||||
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())},
|
||||
SSHOptions{},
|
||||
errors.New("ssh certificate validBefore cannot be in the past"),
|
||||
},
|
||||
{
|
||||
"fail/validBefore-before-validAfter",
|
||||
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())},
|
||||
SSHOptions{},
|
||||
errors.New("ssh certificate validBefore cannot be before validAfter"),
|
||||
},
|
||||
{
|
||||
"fail/cert-type-not-set",
|
||||
&ssh.Certificate{ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())},
|
||||
SSHOptions{},
|
||||
errors.New("ssh certificate type has not been set"),
|
||||
},
|
||||
{
|
||||
|
@ -256,6 +712,7 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
|
|||
ValidAfter: uint64(now().Unix()),
|
||||
ValidBefore: uint64(now().Add(10 * time.Minute).Unix()),
|
||||
},
|
||||
SSHOptions{},
|
||||
errors.New("unknown ssh certificate type 3"),
|
||||
},
|
||||
{
|
||||
|
@ -265,8 +722,19 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
|
|||
ValidAfter: uint64(n.Unix()),
|
||||
ValidBefore: uint64(n.Add(4 * time.Minute).Unix()),
|
||||
},
|
||||
SSHOptions{Backdate: time.Second},
|
||||
errors.New("requested duration of 4m0s is less than minimum accepted duration for selected provisioner of 5m0s"),
|
||||
},
|
||||
{
|
||||
"ok/duration-exactly-min",
|
||||
&ssh.Certificate{
|
||||
CertType: 1,
|
||||
ValidAfter: uint64(n.Unix()),
|
||||
ValidBefore: uint64(n.Add(5 * time.Minute).Unix()),
|
||||
},
|
||||
SSHOptions{Backdate: time.Second},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"fail/duration>max",
|
||||
&ssh.Certificate{
|
||||
|
@ -274,7 +742,18 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
|
|||
ValidAfter: uint64(n.Unix()),
|
||||
ValidBefore: uint64(n.Add(48 * time.Hour).Unix()),
|
||||
},
|
||||
errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m0s"),
|
||||
SSHOptions{Backdate: time.Second},
|
||||
errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m1s"),
|
||||
},
|
||||
{
|
||||
"ok/duration-exactly-max",
|
||||
&ssh.Certificate{
|
||||
CertType: 1,
|
||||
ValidAfter: uint64(n.Unix()),
|
||||
ValidBefore: uint64(n.Add(24*time.Hour + time.Second).Unix()),
|
||||
},
|
||||
SSHOptions{Backdate: time.Second},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"ok",
|
||||
|
@ -283,12 +762,13 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
|
|||
ValidAfter: uint64(now().Unix()),
|
||||
ValidBefore: uint64(now().Add(8 * time.Hour).Unix()),
|
||||
},
|
||||
SSHOptions{Backdate: time.Second},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := v.Valid(tt.cert); err != nil {
|
||||
if err := v.Valid(tt.cert, tt.opts); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
|
@ -505,7 +985,7 @@ func Test_sshDefaultDuration_Option(t *testing.T) {
|
|||
{"host backdate", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}},
|
||||
&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false},
|
||||
{"user validAfter", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}},
|
||||
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Minute), ValidBefore: unix(17 * time.Hour)}, false},
|
||||
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Hour), ValidBefore: unix(17 * time.Hour)}, false},
|
||||
{"user validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}},
|
||||
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false},
|
||||
{"host validAfter validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}},
|
||||
|
@ -541,7 +1021,7 @@ func Test_sshLimitDuration_Option(t *testing.T) {
|
|||
name string
|
||||
fields fields
|
||||
args args
|
||||
want SSHCertificateModifier
|
||||
want SSHCertModifier
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
|
|
|
@ -45,22 +45,22 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var mods []SSHCertificateModifier
|
||||
var validators []SSHCertificateValidator
|
||||
var mods []SSHCertModifier
|
||||
var validators []SSHCertValidator
|
||||
|
||||
for _, op := range signOpts {
|
||||
switch o := op.(type) {
|
||||
// modify the ssh.Certificate
|
||||
case SSHCertificateModifier:
|
||||
case SSHCertModifier:
|
||||
mods = append(mods, o)
|
||||
// modify the ssh.Certificate given the SSHOptions
|
||||
case SSHCertificateOptionModifier:
|
||||
case SSHCertOptionModifier:
|
||||
mods = append(mods, o.Option(opts))
|
||||
// validate the ssh.Certificate
|
||||
case SSHCertificateValidator:
|
||||
case SSHCertValidator:
|
||||
validators = append(validators, o)
|
||||
// validate the given SSHOptions
|
||||
case SSHCertificateOptionsValidator:
|
||||
case SSHCertOptionsValidator:
|
||||
if err := o.Valid(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
|
|||
|
||||
// User provisioners validators
|
||||
for _, v := range validators {
|
||||
if err := v.Valid(cert); err != nil {
|
||||
if err := v.Valid(cert, opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,11 +3,13 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
@ -99,33 +101,31 @@ func (p *SSHPOP) Init(config Config) error {
|
|||
// claims for case specific downstream parsing.
|
||||
// e.g. a Sign request will auth/validate different fields than a Revoke request.
|
||||
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
|
||||
sshCert, err := ExtractSSHPOPCert(token)
|
||||
sshCert, jwt, err := ExtractSSHPOPCert(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "authorizeToken ssh-pop")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err,
|
||||
"sshpop.authorizeToken; error extracting sshpop header from token")
|
||||
}
|
||||
|
||||
// Check for revocation.
|
||||
if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil {
|
||||
return nil, errors.Wrap(err, "authorizeToken ssh-pop")
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"sshpop.authorizeToken; error checking checking sshpop cert revocation")
|
||||
} else if isRevoked {
|
||||
return nil, errors.New("authorizeToken ssh-pop: ssh certificate has been revoked")
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
|
||||
}
|
||||
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
}
|
||||
// Check validity period of the certificate.
|
||||
n := time.Now()
|
||||
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
|
||||
return nil, errors.New("sshpop certificate validAfter is in the future")
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
|
||||
}
|
||||
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
|
||||
return nil, errors.New("sshpop certificate validBefore is in the past")
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
|
||||
}
|
||||
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("ssh public key could not be cast to ssh CryptoPublicKey")
|
||||
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
|
||||
}
|
||||
pubKey := sshCryptoPubKey.CryptoPublicKey()
|
||||
|
||||
|
@ -146,7 +146,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("error: provisioner could could not verify the sshpop header certificate")
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")
|
||||
}
|
||||
|
||||
// Using the ssh certificates key to validate the claims accomplishes two
|
||||
|
@ -156,7 +156,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
// 2. Asserts that the claims are valid - have not been tampered with.
|
||||
var claims sshPOPPayload
|
||||
if err = jwt.Claims(pubKey, &claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; error parsing sshpop token claims")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -165,16 +165,17 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
Issuer: p.Name,
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; invalid sshpop token")
|
||||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token has invalid audience "+
|
||||
"claim (aud): expected %s, but got %s", audiences, claims.Audience)
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errors.New("token subject cannot be empty")
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token subject cannot be empty")
|
||||
}
|
||||
|
||||
claims.sshCert = sshCert
|
||||
|
@ -186,12 +187,13 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke)
|
||||
if err != nil {
|
||||
return err
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
|
||||
}
|
||||
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
|
||||
return errors.New("token subject must be equivalent to certificate serial number")
|
||||
return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
|
||||
"must be equivalent to sshpop certificate serial number")
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeSSHRenew validates the authorization token and extracts/validates
|
||||
|
@ -199,10 +201,10 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
|||
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHRenew)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
|
||||
}
|
||||
if claims.sshCert.CertType != ssh.HostCert {
|
||||
return nil, errors.New("sshpop AuthorizeSSHRenew: sshpop certificate must be a host ssh certificate")
|
||||
return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")
|
||||
}
|
||||
|
||||
return claims.sshCert, nil
|
||||
|
@ -214,51 +216,52 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
|
|||
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHRekey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
|
||||
}
|
||||
if claims.sshCert.CertType != ssh.HostCert {
|
||||
return nil, nil, errors.New("sshpop AuthorizeSSHRekey: sshpop certificate must be a host ssh certificate")
|
||||
return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")
|
||||
}
|
||||
return claims.sshCert, []SignOption{
|
||||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate
|
||||
// in the sshpop header. If the header is missing, an error is returned.
|
||||
func ExtractSSHPOPCert(token string) (*ssh.Certificate, error) {
|
||||
func ExtractSSHPOPCert(token string) (*ssh.Certificate, *jose.JSONWebToken, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
return nil, nil, errors.Wrapf(err, "extractSSHPOPCert; error parsing token")
|
||||
}
|
||||
|
||||
encodedSSHCert, ok := jwt.Headers[0].ExtraHeaders["sshpop"]
|
||||
if !ok {
|
||||
return nil, errors.New("token missing sshpop header")
|
||||
return nil, nil, errors.New("extractSSHPOPCert; token missing sshpop header")
|
||||
}
|
||||
encodedSSHCertStr, ok := encodedSSHCert.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("error unexpected type for sshpop header")
|
||||
return nil, nil, errors.Errorf("extractSSHPOPCert; error unexpected type for sshpop header: "+
|
||||
"want 'string', but got '%T'", encodedSSHCert)
|
||||
}
|
||||
sshCertBytes, err := base64.StdEncoding.DecodeString(encodedSSHCertStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error decoding sshpop header")
|
||||
return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error base64 decoding sshpop header")
|
||||
}
|
||||
sshPub, err := ssh.ParsePublicKey(sshCertBytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing ssh public key")
|
||||
return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error parsing ssh public key")
|
||||
}
|
||||
sshCert, ok := sshPub.(*ssh.Certificate)
|
||||
if !ok {
|
||||
return nil, errors.New("error converting ssh public key to ssh certificate")
|
||||
return nil, nil, errors.New("extractSSHPOPCert; error converting ssh public key to ssh certificate")
|
||||
}
|
||||
return sshCert, nil
|
||||
return sshCert, jwt, nil
|
||||
}
|
||||
|
||||
func bytesForSigning(cert *ssh.Certificate) []byte {
|
||||
|
|
684
authority/provisioner/sshpop_test.go
Normal file
684
authority/provisioner/sshpop_test.go
Normal file
|
@ -0,0 +1,684 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestSSHPOP_Getters(t *testing.T) {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
id := "sshpop/" + p.Name
|
||||
if got := p.GetID(); got != id {
|
||||
t.Errorf("SSHPOP.GetID() = %v, want %v", got, id)
|
||||
}
|
||||
if got := p.GetName(); got != p.Name {
|
||||
t.Errorf("SSHPOP.GetName() = %v, want %v", got, p.Name)
|
||||
}
|
||||
if got := p.GetType(); got != TypeSSHPOP {
|
||||
t.Errorf("SSHPOP.GetType() = %v, want %v", got, TypeSSHPOP)
|
||||
}
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != "" || key != "" || ok == true {
|
||||
t.Errorf("SSHPOP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, "", "", false)
|
||||
}
|
||||
}
|
||||
|
||||
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cert.Key, err = ssh.NewPublicKey(jwk.Public().Key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err = cert.SignCert(rand.Reader, signer); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return cert, jwk, nil
|
||||
}
|
||||
|
||||
func generateSSHPOPToken(p Interface, cert *ssh.Certificate, jwk *jose.JSONWebKey) (string, error) {
|
||||
return generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
}
|
||||
|
||||
func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
signer, ok := key.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
|
||||
sshSigner, err := ssh.NewSignerFromSigner(signer)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type test struct {
|
||||
p *SSHPOP
|
||||
token string
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
|
||||
}
|
||||
},
|
||||
"fail/error-revoked-db-check": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, errors.New("force")
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusInternalServerError,
|
||||
err: errors.New("sshpop.authorizeToken; error checking checking sshpop cert revocation: force"),
|
||||
}
|
||||
},
|
||||
"fail/cert-already-revoked": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; sshpop certificate is revoked"),
|
||||
}
|
||||
},
|
||||
"fail/cert-not-yet-valid": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{
|
||||
CertType: ssh.UserCert,
|
||||
ValidAfter: uint64(time.Now().Add(time.Minute).Unix()),
|
||||
}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future"),
|
||||
}
|
||||
},
|
||||
"fail/cert-past-validity": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{
|
||||
CertType: ssh.UserCert,
|
||||
ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()),
|
||||
}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past"),
|
||||
}
|
||||
},
|
||||
"fail/no-signer-found": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate"),
|
||||
}
|
||||
},
|
||||
"fail/error-parsing-claims-bad-sig": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, otherJWK)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; error parsing sshpop token claims"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-claims-issuer": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; invalid sshpop token"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-audience": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), "invalid-aud", "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; sshpop token has invalid audience claim (aud)"),
|
||||
}
|
||||
},
|
||||
"fail/empty-subject": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.NotNil(t, claims)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
|
||||
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
signer, ok := key.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
|
||||
sshSigner, err := ssh.NewSignerFromSigner(signer)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type test struct {
|
||||
p *SSHPOP
|
||||
token string
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.AuthorizeSSHRevoke: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
|
||||
}
|
||||
},
|
||||
"fail/subject-not-equal-serial": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusBadRequest,
|
||||
err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
|
||||
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
userSigner, ok := key.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer")
|
||||
sshUserSigner, err := ssh.NewSignerFromSigner(userSigner)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
hostSigner, ok := hostKey.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
|
||||
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type test struct {
|
||||
p *SSHPOP
|
||||
token string
|
||||
cert *ssh.Certificate
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.AuthorizeSSHRenew: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
|
||||
}
|
||||
},
|
||||
"fail/not-host-cert": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusBadRequest,
|
||||
err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
cert: cert,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
||||
key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
userSigner, ok := key.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer")
|
||||
sshUserSigner, err := ssh.NewSignerFromSigner(userSigner)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
hostSigner, ok := hostKey.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
|
||||
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type test struct {
|
||||
p *SSHPOP
|
||||
token string
|
||||
cert *ssh.Certificate
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.AuthorizeSSHRekey: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
|
||||
}
|
||||
},
|
||||
"fail/not-host-cert": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusBadRequest,
|
||||
err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
cert: cert,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Len(t, 3, opts)
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case *sshDefaultPublicKeyValidator:
|
||||
case *sshCertDefaultValidator:
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
}
|
||||
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHPOP_ExtractSSHPOPCert(t *testing.T) {
|
||||
hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
hostSigner, ok := hostKey.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
|
||||
sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type test struct {
|
||||
token string
|
||||
cert *ssh.Certificate
|
||||
jwk *jose.JSONWebKey
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/bad-token": func(t *testing.T) test {
|
||||
return test{
|
||||
token: "foo",
|
||||
err: errors.New("extractSSHPOPCert; error parsing token"),
|
||||
}
|
||||
},
|
||||
"fail/sshpop-missing": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("sub", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
token: tok,
|
||||
err: errors.New("extractSSHPOPCert; token missing sshpop header"),
|
||||
}
|
||||
},
|
||||
"fail/wrong-sshpop-type": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
|
||||
so.WithHeader("sshpop", 12345)
|
||||
return nil
|
||||
})
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
token: tok,
|
||||
err: errors.New("extractSSHPOPCert; error unexpected type for sshpop header: "),
|
||||
}
|
||||
},
|
||||
"fail/base64decode-error": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
|
||||
so.WithHeader("sshpop", "!@#$%^&*")
|
||||
return nil
|
||||
})
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
token: tok,
|
||||
err: errors.New("extractSSHPOPCert; error base64 decoding sshpop header: illegal base64"),
|
||||
}
|
||||
},
|
||||
"fail/parsing-sshpop-pubkey": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
|
||||
so.WithHeader("sshpop", base64.StdEncoding.EncodeToString([]byte("foo")))
|
||||
return nil
|
||||
})
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
token: tok,
|
||||
err: errors.New("extractSSHPOPCert; error parsing ssh public key"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
||||
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
token: tok,
|
||||
jwk: jwk,
|
||||
cert: cert,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if cert, jwt, err := ExtractSSHPOPCert(tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
|
||||
assert.Equals(t, tc.jwk.KeyID, jwt.Headers[0].KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1
authority/provisioner/testdata/certs/ssh_host_ca_key.pub
vendored
Normal file
1
authority/provisioner/testdata/certs/ssh_host_ca_key.pub
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY=
|
1
authority/provisioner/testdata/certs/ssh_user_ca_key.pub
vendored
Normal file
1
authority/provisioner/testdata/certs/ssh_user_ca_key.pub
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y=
|
5
authority/provisioner/testdata/secrets/bar_host_ssh_key
vendored
Normal file
5
authority/provisioner/testdata/secrets/bar_host_ssh_key
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIHzAUYu3h8e1gL5ONGZo+lghJJa9rl1TvP2UlqDXazxvoAoGCCqGSM49
|
||||
AwEHoUQDQgAEOLScS+1Yzmqdyots9lSC0tzTSXUXEgyOD9wYrQ0BqnVZtBXlQw1p
|
||||
m3fnF/7Ehl6bD1YZWjrF1t+IBZQMq1uBBw==
|
||||
-----END EC PRIVATE KEY-----
|
5
authority/provisioner/testdata/secrets/foo_user_ssh_key
vendored
Normal file
5
authority/provisioner/testdata/secrets/foo_user_ssh_key
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEINWGD2xneE43YeytQzORItISxv6d/oH+9TXvDKHo6TyXoAoGCCqGSM49
|
||||
AwEHoUQDQgAEVK/EtXgVV7+7ppnQSjCtI5qb/gIGnQUF4i//F/JKKho7kRNyMDSn
|
||||
BP3kndiv8Yfxg4PsyIRY5ZofbEo5eJE6bg==
|
||||
-----END EC PRIVATE KEY-----
|
5
authority/provisioner/testdata/secrets/ssh_host_ca_key
vendored
Normal file
5
authority/provisioner/testdata/secrets/ssh_host_ca_key
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49
|
||||
AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+
|
||||
MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg==
|
||||
-----END EC PRIVATE KEY-----
|
5
authority/provisioner/testdata/secrets/ssh_user_ca_key
vendored
Normal file
5
authority/provisioner/testdata/secrets/ssh_user_ca_key
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49
|
||||
AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h
|
||||
YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg==
|
||||
-----END EC PRIVATE KEY-----
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -47,24 +48,6 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
func provisionerClaims() *Claims {
|
||||
ddr := false
|
||||
des := true
|
||||
return &Claims{
|
||||
MinTLSDur: &Duration{5 * time.Minute},
|
||||
MaxTLSDur: &Duration{24 * time.Hour},
|
||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||
DisableRenewal: &ddr,
|
||||
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &Duration{Duration: 4 * time.Hour},
|
||||
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &des,
|
||||
}
|
||||
}
|
||||
|
||||
const awsTestCertificate = `-----BEGIN CERTIFICATE-----
|
||||
MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw
|
||||
GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0
|
||||
|
@ -204,7 +187,7 @@ func generateJWK() (*JWK, error) {
|
|||
}
|
||||
|
||||
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
|
||||
fooPubB, err := ioutil.ReadFile("./testdata/foo.pub")
|
||||
fooPubB, err := ioutil.ReadFile("./testdata/certs/foo.pub")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -212,7 +195,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
barPubB, err := ioutil.ReadFile("./testdata/bar.pub")
|
||||
barPubB, err := ioutil.ReadFile("./testdata/certs/bar.pub")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -240,6 +223,46 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func generateSSHPOP() (*SSHPOP, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userB, err := ioutil.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userKey, _, _, _, err := ssh.ParseAuthorizedKey(userB)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostB, err := ioutil.ReadFile("./testdata/certs/ssh_host_ca_key.pub")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(hostB)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SSHPOP{
|
||||
Name: name,
|
||||
Type: "SSHPOP",
|
||||
Claims: &globalProvisionerClaims,
|
||||
audiences: testAudiences,
|
||||
claimer: claimer,
|
||||
sshPubKeys: &SSHKeys{
|
||||
UserKeys: []ssh.PublicKey{userKey},
|
||||
HostKeys: []ssh.PublicKey{hostKey},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateX5C(root []byte) (*X5C, error) {
|
||||
if root == nil {
|
||||
root = []byte(`-----BEGIN CERTIFICATE-----
|
||||
|
@ -589,6 +612,13 @@ func withX5CHdr(certs []*x509.Certificate) tokOption {
|
|||
}
|
||||
}
|
||||
|
||||
func withSSHPOPFile(cert *ssh.Certificate) tokOption {
|
||||
return func(so *jose.SignerOptions) error {
|
||||
so.WithHeader("sshpop", base64.StdEncoding.EncodeToString(cert.Marshal()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithType("JWT")
|
||||
|
@ -630,6 +660,24 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
|
|||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func generateX5CSSHToken(jwk *jose.JSONWebKey, claims *x5cPayload, tokOpts ...tokOption) (string, error) {
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithType("JWT")
|
||||
so.WithHeader("kid", jwk.KeyID)
|
||||
|
||||
for _, o := range tokOpts {
|
||||
if err := o(so); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func getK8sSAPayload() *k8sSAPayload {
|
||||
return &k8sSAPayload{
|
||||
Claims: jose.Claims{
|
||||
|
|
|
@ -4,9 +4,11 @@ import (
|
|||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
@ -121,19 +123,20 @@ func (p *X5C) Init(config Config) error {
|
|||
func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c token")
|
||||
}
|
||||
|
||||
verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{
|
||||
Roots: p.rootPool,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error verifying x5c certificate chain")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err,
|
||||
"x5c.authorizeToken; error verifying x5c certificate chain in token")
|
||||
}
|
||||
leaf := verifiedChains[0][0]
|
||||
|
||||
if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 {
|
||||
return nil, errors.New("certificate used to sign x5c token cannot be used for digital signature")
|
||||
return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
|
||||
}
|
||||
|
||||
// Using the leaf certificates key to validate the claims accomplishes two
|
||||
|
@ -143,7 +146,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
|
|||
// 2. Asserts that the claims are valid - have not been tampered with.
|
||||
var claims x5cPayload
|
||||
if err = jwt.Claims(leaf.PublicKey, &claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c claims")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -152,16 +155,17 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
|
|||
Issuer: p.Name,
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token")
|
||||
return nil, errs.Wrapf(http.StatusUnauthorized, err, "x5c.authorizeToken; invalid x5c claims")
|
||||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||
return nil, errs.Unauthorized("x5c.authorizeToken; x5c token has invalid audience "+
|
||||
"claim (aud); expected %s, but got %s", audiences, claims.Audience)
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errors.New("token subject cannot be empty")
|
||||
return nil, errs.Unauthorized("x5c.authorizeToken; x5c token subject cannot be empty")
|
||||
}
|
||||
|
||||
// Save the verified chains on the x5c payload object.
|
||||
|
@ -173,14 +177,14 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
|
|||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
||||
return err
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
|
||||
}
|
||||
|
||||
// NOTE: This is for backwards compatibility with older versions of cli
|
||||
|
@ -209,7 +213,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -217,22 +221,22 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())
|
||||
}
|
||||
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
|
||||
}
|
||||
|
||||
if claims.Step == nil || claims.Step.SSH == nil {
|
||||
return nil, errors.New("authorization token must be an SSH provisioning token")
|
||||
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")
|
||||
}
|
||||
|
||||
opts := claims.Step.SSH
|
||||
signOptions := []SignOption{
|
||||
// validates user's SSHOptions with the ones in the token
|
||||
sshCertificateOptionsValidator(*opts),
|
||||
sshCertOptionsValidator(*opts),
|
||||
}
|
||||
|
||||
// Add modifiers from custom claims
|
||||
|
@ -245,18 +249,18 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
}
|
||||
t := now()
|
||||
if !opts.ValidAfter.IsZero() {
|
||||
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
|
||||
signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
|
||||
}
|
||||
if !opts.ValidBefore.IsZero() {
|
||||
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
|
||||
signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
|
||||
}
|
||||
// Make sure to define the the KeyID
|
||||
if opts.KeyID == "" {
|
||||
signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject))
|
||||
signOptions = append(signOptions, sshCertKeyIDModifier(claims.Subject))
|
||||
}
|
||||
|
||||
// Default to a user certificate with no principals if not set
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier{CertType: SSHUserCert})
|
||||
|
||||
return append(signOptions,
|
||||
// Set the default extensions.
|
||||
|
@ -268,8 +272,8 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate public key.
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -2,14 +2,16 @@ package provisioner
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -151,9 +153,15 @@ M46l92gdOozT
|
|||
}
|
||||
|
||||
func TestX5C_authorizeToken(t *testing.T) {
|
||||
x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
x5cJWK, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type test struct {
|
||||
p *X5C
|
||||
token string
|
||||
code int
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
|
@ -163,7 +171,8 @@ func TestX5C_authorizeToken(t *testing.T) {
|
|||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
err: errors.New("error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.authorizeToken; error parsing x5c token"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-cert-chain": func(t *testing.T) test {
|
||||
|
@ -190,7 +199,8 @@ a5wpg+9s6QIgHIW6L60F8klQX+EO3o0SBqLeNcaskA4oSZsKjEdpSGo=
|
|||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("error verifying x5c certificate chain: x509: certificate signed by unknown authority"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"),
|
||||
}
|
||||
},
|
||||
"fail/doubled-up-self-signed-cert": func(t *testing.T) test {
|
||||
|
@ -228,7 +238,8 @@ EXAHTA9L
|
|||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("error verifying x5c certificate chain: x509: certificate signed by unknown authority"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"),
|
||||
}
|
||||
},
|
||||
"fail/digital-signature-ext-required": func(t *testing.T) test {
|
||||
|
@ -269,7 +280,8 @@ lgsqsR63is+0YQ==
|
|||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("certificate used to sign x5c token cannot be used for digital signature"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature"),
|
||||
}
|
||||
},
|
||||
"fail/signature-does-not-match-x5c-pub-key": func(t *testing.T) test {
|
||||
|
@ -309,74 +321,58 @@ lgsqsR63is+0YQ==
|
|||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("error parsing claims: square/go-jose: error in cryptographic primitive"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.authorizeToken; error parsing x5c claims"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-issuer": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("", "foobar", testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk,
|
||||
withX5CHdr(certs))
|
||||
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
|
||||
withX5CHdr(x5cCerts))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.authorizeToken; invalid x5c claims"),
|
||||
}
|
||||
},
|
||||
"fail/invalid-audience": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("", p.GetName(), "foobar", "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk,
|
||||
withX5CHdr(certs))
|
||||
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
|
||||
withX5CHdr(x5cCerts))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("invalid token: invalid audience claim (aud)"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.authorizeToken; x5c token has invalid audience claim (aud)"),
|
||||
}
|
||||
},
|
||||
"fail/empty-subject": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk,
|
||||
withX5CHdr(certs))
|
||||
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
|
||||
withX5CHdr(x5cCerts))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
err: errors.New("token subject cannot be empty"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.authorizeToken; x5c token subject cannot be empty"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk,
|
||||
withX5CHdr(certs))
|
||||
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
|
||||
withX5CHdr(x5cCerts))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
|
@ -389,6 +385,9 @@ lgsqsR63is+0YQ==
|
|||
tc := tt(t)
|
||||
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
|
@ -402,10 +401,15 @@ lgsqsR63is+0YQ==
|
|||
}
|
||||
|
||||
func TestX5C_AuthorizeSign(t *testing.T) {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type test struct {
|
||||
p *X5C
|
||||
token string
|
||||
ctx context.Context
|
||||
code int
|
||||
err error
|
||||
dns []string
|
||||
emails []string
|
||||
|
@ -418,56 +422,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
|||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
ctx: NewContextWithMethod(context.Background(), SignMethod),
|
||||
err: errors.New("error parsing token"),
|
||||
}
|
||||
},
|
||||
"fail/ssh/disabled": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
p.claimer.claims = provisionerClaims()
|
||||
*p.claimer.claims.EnableSSHCA = false
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk,
|
||||
withX5CHdr(certs))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
|
||||
token: tok,
|
||||
err: errors.Errorf("ssh ca is disabled for provisioner x5c/%s", p.GetName()),
|
||||
}
|
||||
},
|
||||
"fail/ssh/invalid-token": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk,
|
||||
withX5CHdr(certs))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
ctx: NewContextWithMethod(context.Background(), SignSSHMethod),
|
||||
token: tok,
|
||||
err: errors.New("authorization token must be an SSH provisioning token"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.AuthorizeSign: x5c.authorizeToken; error parsing x5c token"),
|
||||
}
|
||||
},
|
||||
"ok/empty-sans": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
|
||||
|
@ -476,7 +435,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
ctx: NewContextWithMethod(context.Background(), SignMethod),
|
||||
token: tok,
|
||||
dns: []string{"foo"},
|
||||
emails: []string{},
|
||||
|
@ -484,11 +442,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok/multi-sans": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "",
|
||||
|
@ -497,7 +450,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
ctx: NewContextWithMethod(context.Background(), SignMethod),
|
||||
token: tok,
|
||||
dns: []string{"foo"},
|
||||
emails: []string{"max@smallstep.com"},
|
||||
|
@ -508,8 +460,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token); err != nil {
|
||||
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
|
@ -554,126 +509,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
||||
_, fn := mockNow()
|
||||
defer fn()
|
||||
type test struct {
|
||||
p *X5C
|
||||
token string
|
||||
claims *x5cPayload
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/no-Step-claim": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
claims: new(x5cPayload),
|
||||
err: errors.New("authorization token must be an SSH provisioning token"),
|
||||
}
|
||||
},
|
||||
"fail/no-SSH-subattribute-in-claims": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
claims: &x5cPayload{Step: new(stepPayload)},
|
||||
err: errors.New("authorization token must be an SSH provisioning token"),
|
||||
}
|
||||
},
|
||||
"ok/with-claims": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
claims: &x5cPayload{
|
||||
Step: &stepPayload{SSH: &SSHOptions{
|
||||
CertType: SSHHostCert,
|
||||
Principals: []string{"max", "mariano", "alan"},
|
||||
ValidAfter: TimeDuration{d: 5 * time.Minute},
|
||||
ValidBefore: TimeDuration{d: 10 * time.Minute},
|
||||
}},
|
||||
Claims: jose.Claims{Subject: "foo"},
|
||||
chains: [][]*x509.Certificate{certs},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/without-claims": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
claims: &x5cPayload{
|
||||
Step: &stepPayload{SSH: &SSHOptions{}},
|
||||
Claims: jose.Claims{Subject: "foo"},
|
||||
chains: [][]*x509.Certificate{certs},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
if assert.NotNil(t, opts) {
|
||||
tot := 0
|
||||
nw := now()
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case sshCertificateOptionsValidator:
|
||||
tc.claims.Step.SSH.ValidAfter.t = time.Time{}
|
||||
tc.claims.Step.SSH.ValidBefore.t = time.Time{}
|
||||
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
|
||||
case sshCertificateKeyIDModifier:
|
||||
assert.Equals(t, string(v), "foo")
|
||||
case sshCertTypeModifier:
|
||||
assert.Equals(t, string(v), tc.claims.Step.SSH.CertType)
|
||||
case sshCertPrincipalsModifier:
|
||||
assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals)
|
||||
case sshCertificateValidAfterModifier:
|
||||
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
|
||||
case sshCertificateValidBeforeModifier:
|
||||
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
|
||||
case sshCertificateDefaultsModifier:
|
||||
assert.Equals(t, SSHOptions(v), SSHOptions{CertType: SSHUserCert})
|
||||
case *sshLimitDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
assert.Equals(t, v.NotAfter, tc.claims.chains[0][0].NotAfter)
|
||||
case *sshCertificateValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
|
||||
*sshCertificateDefaultValidator:
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
tot++
|
||||
}
|
||||
if len(tc.claims.Step.SSH.CertType) > 0 {
|
||||
assert.Equals(t, tot, 12)
|
||||
} else {
|
||||
assert.Equals(t, tot, 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestX5C_AuthorizeRevoke(t *testing.T) {
|
||||
type test struct {
|
||||
p *X5C
|
||||
token string
|
||||
code int
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
|
@ -683,13 +523,14 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
|
|||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
err: errors.New("error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.AuthorizeRevoke: x5c.authorizeToken; error parsing x5c token"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt")
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.ParseKey("./testdata/x5c-leaf.key")
|
||||
jwk, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
|
@ -707,8 +548,11 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil {
|
||||
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
|
@ -719,33 +563,248 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestX5C_AuthorizeRenew(t *testing.T) {
|
||||
p1, err := generateX5C(nil)
|
||||
type test struct {
|
||||
p *X5C
|
||||
code int
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/renew-disabled": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
p.Claims = &Claims{DisableRenewal: &disable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
return test{
|
||||
p: p,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
||||
x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
|
||||
assert.FatalError(t, err)
|
||||
x5cJWK, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_, fn := mockNow()
|
||||
defer fn()
|
||||
type test struct {
|
||||
p *X5C
|
||||
token string
|
||||
claims *x5cPayload
|
||||
code int
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) test{
|
||||
"fail/sshCA-disabled": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
// disable sshCA
|
||||
enable := false
|
||||
p.Claims = &Claims{EnableSSHCA: &enable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()),
|
||||
}
|
||||
},
|
||||
"fail/invalid-token": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.AuthorizeSSHSign: x5c.authorizeToken; error parsing x5c token"),
|
||||
}
|
||||
},
|
||||
"fail/no-Step-claim": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHSign[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), x5cJWK,
|
||||
withX5CHdr(x5cCerts))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"),
|
||||
}
|
||||
},
|
||||
"fail/no-SSH-subattribute-in-claims": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
id, err := randutil.ASCII(64)
|
||||
assert.FatalError(t, err)
|
||||
now := time.Now()
|
||||
claims := &x5cPayload{
|
||||
Claims: jose.Claims{
|
||||
ID: id,
|
||||
Subject: "foo",
|
||||
Issuer: p.GetName(),
|
||||
IssuedAt: jose.NewNumericDate(now),
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
Audience: []string{testAudiences.SSHSign[0]},
|
||||
},
|
||||
Step: &stepPayload{},
|
||||
}
|
||||
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"),
|
||||
}
|
||||
},
|
||||
"ok/with-claims": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
id, err := randutil.ASCII(64)
|
||||
assert.FatalError(t, err)
|
||||
now := time.Now()
|
||||
claims := &x5cPayload{
|
||||
Claims: jose.Claims{
|
||||
ID: id,
|
||||
Subject: "foo",
|
||||
Issuer: p.GetName(),
|
||||
IssuedAt: jose.NewNumericDate(now),
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
Audience: []string{testAudiences.SSHSign[0]},
|
||||
},
|
||||
Step: &stepPayload{SSH: &SSHOptions{
|
||||
CertType: SSHHostCert,
|
||||
Principals: []string{"max", "mariano", "alan"},
|
||||
ValidAfter: TimeDuration{d: 5 * time.Minute},
|
||||
ValidBefore: TimeDuration{d: 10 * time.Minute},
|
||||
}},
|
||||
}
|
||||
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
claims: claims,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
"ok/without-claims": func(t *testing.T) test {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
id, err := randutil.ASCII(64)
|
||||
assert.FatalError(t, err)
|
||||
now := time.Now()
|
||||
claims := &x5cPayload{
|
||||
Claims: jose.Claims{
|
||||
ID: id,
|
||||
Subject: "foo",
|
||||
Issuer: p.GetName(),
|
||||
IssuedAt: jose.NewNumericDate(now),
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
Audience: []string{testAudiences.SSHSign[0]},
|
||||
},
|
||||
Step: &stepPayload{SSH: &SSHOptions{}},
|
||||
}
|
||||
tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts))
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
claims: claims,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
if assert.NotNil(t, opts) {
|
||||
tot := 0
|
||||
nw := now()
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case sshCertOptionsValidator:
|
||||
tc.claims.Step.SSH.ValidAfter.t = time.Time{}
|
||||
tc.claims.Step.SSH.ValidBefore.t = time.Time{}
|
||||
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
|
||||
case sshCertKeyIDModifier:
|
||||
assert.Equals(t, string(v), "foo")
|
||||
case sshCertTypeModifier:
|
||||
assert.Equals(t, string(v), tc.claims.Step.SSH.CertType)
|
||||
case sshCertPrincipalsModifier:
|
||||
assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals)
|
||||
case sshCertValidAfterModifier:
|
||||
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
|
||||
case sshCertValidBeforeModifier:
|
||||
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
|
||||
case sshCertDefaultsModifier:
|
||||
assert.Equals(t, SSHOptions(v), SSHOptions{CertType: SSHUserCert})
|
||||
case *sshLimitDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
|
||||
*sshCertDefaultValidator:
|
||||
case sshCertKeyIDValidator:
|
||||
assert.Equals(t, string(v), "foo")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
tot++
|
||||
}
|
||||
if len(tc.claims.Step.SSH.CertType) > 0 {
|
||||
assert.Equals(t, tot, 13)
|
||||
} else {
|
||||
assert.Equals(t, tot, 9)
|
||||
}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *X5C
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("X5C.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,18 +2,16 @@ package authority
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||
key, ok := a.provisioners.LoadEncryptedKey(kid)
|
||||
if !ok {
|
||||
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
return "", errs.NotFound("encrypted key with kid %s was not found", kid)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
@ -30,8 +28,7 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
|
|||
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
|
||||
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("provisioner not found"),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
return nil, errs.NotFound("provisioner not found")
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
@ -40,8 +37,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
|
|||
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
p, ok := a.provisioners.Load(id)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("provisioner not found"),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
return nil, errs.NotFound("provisioner not found")
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
|
|
@ -7,13 +7,15 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
func TestGetEncryptedKey(t *testing.T) {
|
||||
type ek struct {
|
||||
a *Authority
|
||||
kid string
|
||||
err *apiError
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(t *testing.T) *ek{
|
||||
"ok": func(t *testing.T) *ek {
|
||||
|
@ -34,8 +36,8 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
return &ek{
|
||||
a: a,
|
||||
kid: "foo",
|
||||
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
|
||||
http.StatusNotFound, apiCtx{}},
|
||||
err: errors.New("encrypted key with kid foo was not found"),
|
||||
code: http.StatusNotFound,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -47,14 +49,10 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
ek, err := tc.a.GetEncryptedKey(tc.kid)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch v := err.(type) {
|
||||
case *apiError:
|
||||
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.code, tc.err.code)
|
||||
assert.Equals(t, v.context, tc.err.context)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
|
@ -72,7 +70,8 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
func TestGetProvisioners(t *testing.T) {
|
||||
type gp struct {
|
||||
a *Authority
|
||||
err *apiError
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(t *testing.T) *gp{
|
||||
"ok": func(t *testing.T) *gp {
|
||||
|
@ -91,14 +90,10 @@ func TestGetProvisioners(t *testing.T) {
|
|||
ps, next, err := tc.a.GetProvisioners("", 0)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch v := err.(type) {
|
||||
case *apiError:
|
||||
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.code, tc.err.code)
|
||||
assert.Equals(t, v.context, tc.err.context)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
|
|
|
@ -2,23 +2,20 @@ package authority
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// Root returns the certificate corresponding to the given SHA sum argument.
|
||||
func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
||||
val, ok := a.certificates.Load(sum)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
return nil, errs.NotFound("certificate with fingerprint %s was not found", sum)
|
||||
}
|
||||
|
||||
crt, ok := val.(*x509.Certificate)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
return nil, errs.InternalServer("stored value is not a *x509.Certificate")
|
||||
}
|
||||
return crt, nil
|
||||
}
|
||||
|
@ -52,8 +49,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
|
|||
crt, ok := v.(*x509.Certificate)
|
||||
if !ok {
|
||||
federation = nil
|
||||
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
err = errs.InternalServer("stored value is not a *x509.Certificate")
|
||||
return false
|
||||
}
|
||||
federation = append(federation, crt)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
|
@ -17,11 +18,12 @@ func TestRoot(t *testing.T) {
|
|||
|
||||
tests := map[string]struct {
|
||||
sum string
|
||||
err *apiError
|
||||
err error
|
||||
code int
|
||||
}{
|
||||
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
|
||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
|
||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
||||
"not-found": {"foo", errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound},
|
||||
"invalid-stored-certificate": {"invaliddata", errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError},
|
||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil, http.StatusOK},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
|
@ -29,14 +31,10 @@ func TestRoot(t *testing.T) {
|
|||
crt, err := a.Root(tc.sum)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch v := err.(type) {
|
||||
case *apiError:
|
||||
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.code, tc.err.code)
|
||||
assert.Equals(t, v.context, tc.err.context)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
|
|
276
authority/ssh.go
276
authority/ssh.go
|
@ -122,10 +122,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
|
|||
// GetSSHConfig returns rendered templates for clients (user) or servers (host).
|
||||
func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
|
||||
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("getSSHConfig: ssh is not configured"),
|
||||
code: http.StatusNotFound,
|
||||
}
|
||||
return nil, errs.NotFound("getSSHConfig: ssh is not configured")
|
||||
}
|
||||
|
||||
var ts []templates.Template
|
||||
|
@ -139,10 +136,7 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
|
|||
ts = a.config.Templates.SSH.Host
|
||||
}
|
||||
default:
|
||||
return nil, &apiError{
|
||||
err: errors.Errorf("getSSHConfig: type %s is not valid", typ),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ)
|
||||
}
|
||||
|
||||
// Merge user and default data
|
||||
|
@ -174,7 +168,8 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
|
|||
// hostname.
|
||||
func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) {
|
||||
if a.sshBastionFunc != nil {
|
||||
return a.sshBastionFunc(user, hostname)
|
||||
bs, err := a.sshBastionFunc(user, hostname)
|
||||
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
|
||||
}
|
||||
if a.config.SSH != nil {
|
||||
if a.config.SSH.Bastion != nil && a.config.SSH.Bastion.Hostname != "" {
|
||||
|
@ -182,32 +177,13 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
|
|||
}
|
||||
return nil, nil
|
||||
}
|
||||
return nil, &apiError{
|
||||
err: errors.New("getSSHBastion: ssh is not configured"),
|
||||
code: http.StatusNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeSSHSign loads the provisioner from the token, checks that it has not
|
||||
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
|
||||
// list of methods to apply to the signing flow.
|
||||
func (a *Authority) authorizeSSHSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
var errContext = apiCtx{"ott": ott}
|
||||
p, err := a.authorizeToken(ctx, ott)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorizeSSHSign"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
opts, err := p.AuthorizeSSHSign(ctx, ott)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorizeSSHSign"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
return opts, nil
|
||||
return nil, errs.NotFound("authority.GetSSHBastion; ssh is not configured")
|
||||
}
|
||||
|
||||
// SignSSH creates a signed SSH certificate with the given public key and options.
|
||||
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
var mods []provisioner.SSHCertificateModifier
|
||||
var validators []provisioner.SSHCertificateValidator
|
||||
var mods []provisioner.SSHCertModifier
|
||||
var validators []provisioner.SSHCertValidator
|
||||
|
||||
// Set backdate with the configured value
|
||||
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
|
||||
|
@ -215,38 +191,32 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
for _, op := range signOpts {
|
||||
switch o := op.(type) {
|
||||
// modify the ssh.Certificate
|
||||
case provisioner.SSHCertificateModifier:
|
||||
case provisioner.SSHCertModifier:
|
||||
mods = append(mods, o)
|
||||
// modify the ssh.Certificate given the SSHOptions
|
||||
case provisioner.SSHCertificateOptionModifier:
|
||||
case provisioner.SSHCertOptionModifier:
|
||||
mods = append(mods, o.Option(opts))
|
||||
// validate the ssh.Certificate
|
||||
case provisioner.SSHCertificateValidator:
|
||||
case provisioner.SSHCertValidator:
|
||||
validators = append(validators, o)
|
||||
// validate the given SSHOptions
|
||||
case provisioner.SSHCertificateOptionsValidator:
|
||||
case provisioner.SSHCertOptionsValidator:
|
||||
if err := o.Valid(opts); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
||||
}
|
||||
default:
|
||||
return nil, &apiError{
|
||||
err: errors.Errorf("signSSH: invalid extra option type %T", o),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.InternalServer("signSSH: invalid extra option type %T", o)
|
||||
}
|
||||
}
|
||||
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH")
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSH: error reading random number"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error reading random number")
|
||||
}
|
||||
|
||||
// Build base certificate with the key and some random values
|
||||
|
@ -258,13 +228,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
|
||||
// Use opts to modify the certificate
|
||||
if err := opts.Modify(cert); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
||||
}
|
||||
|
||||
// Use provisioner modifiers
|
||||
for _, m := range mods {
|
||||
if err := m.Modify(cert); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -273,25 +243,16 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSH: user certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAUserCertSignKey
|
||||
case ssh.HostCert:
|
||||
if a.sshCAHostCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSH: host certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAHostCertSignKey
|
||||
default:
|
||||
return nil, &apiError{
|
||||
err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType)
|
||||
}
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
|
@ -302,71 +263,38 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
// Sign the certificate
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSH: error signing certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
|
||||
}
|
||||
cert.Signature = sig
|
||||
|
||||
// User provisioners validators
|
||||
for _, v := range validators {
|
||||
if err := v.Valid(cert); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
if err := v.Valid(cert, opts); err != nil {
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
||||
}
|
||||
}
|
||||
|
||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSH: error storing certificate in db"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error storing certificate in db")
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// authorizeSSHRenew authorizes an SSH certificate renewal request, by
|
||||
// validating the contents of an SSHPOP token.
|
||||
func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
||||
errContext := map[string]interface{}{"ott": token}
|
||||
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "authorizeSSHRenew"),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
cert, err := p.AuthorizeSSHRenew(ctx, token)
|
||||
if err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "authorizeSSHRenew"),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
|
||||
func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "renewSSH: error reading random number"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error reading random number")
|
||||
}
|
||||
|
||||
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
|
||||
return nil, errors.New("rewnewSSH: cannot renew certificate without validity period")
|
||||
return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period")
|
||||
}
|
||||
|
||||
backdate := a.config.AuthorityConfig.Backdate.Duration
|
||||
|
@ -393,25 +321,16 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
|
|||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("renewSSH: user certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
return nil, errs.NotImplemented("renewSSH: user certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAUserCertSignKey
|
||||
case ssh.HostCert:
|
||||
if a.sshCAHostCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("renewSSH: host certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
return nil, errs.NotImplemented("renewSSH: host certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAHostCertSignKey
|
||||
default:
|
||||
return nil, &apiError{
|
||||
err: errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.InternalServer("renewSSH: unexpected ssh certificate type: %d", cert.CertType)
|
||||
}
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
|
@ -422,79 +341,43 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
|
|||
// Sign the certificate
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "renewSSH: error signing certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error signing certificate")
|
||||
}
|
||||
cert.Signature = sig
|
||||
|
||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "renewSSH: error storing certificate in db"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db")
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// authorizeSSHRekey authorizes an SSH certificate rekey request, by
|
||||
// validating the contents of an SSHPOP token.
|
||||
func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) {
|
||||
errContext := map[string]interface{}{"ott": token}
|
||||
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{
|
||||
err: errors.Wrap(err, "authorizeSSHRenew"),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
cert, opts, err := p.AuthorizeSSHRekey(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{
|
||||
err: errors.Wrap(err, "authorizeSSHRekey"),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
return cert, opts, nil
|
||||
}
|
||||
|
||||
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
|
||||
func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
var validators []provisioner.SSHCertificateValidator
|
||||
var validators []provisioner.SSHCertValidator
|
||||
|
||||
for _, op := range signOpts {
|
||||
switch o := op.(type) {
|
||||
// validate the ssh.Certificate
|
||||
case provisioner.SSHCertificateValidator:
|
||||
case provisioner.SSHCertValidator:
|
||||
validators = append(validators, o)
|
||||
default:
|
||||
return nil, &apiError{
|
||||
err: errors.Errorf("rekeySSH: invalid extra option type %T", o),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.InternalServer("rekeySSH; invalid extra option type %T", o)
|
||||
}
|
||||
}
|
||||
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH")
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "rekeySSH: error reading random number"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error reading random number")
|
||||
}
|
||||
|
||||
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
|
||||
return nil, errors.New("rekeySSH: cannot rekey certificate without validity period")
|
||||
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
|
||||
}
|
||||
|
||||
backdate := a.config.AuthorityConfig.Backdate.Duration
|
||||
|
@ -521,25 +404,16 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
|
|||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("rekeySSH: user certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
return nil, errs.NotImplemented("rekeySSH; user certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAUserCertSignKey
|
||||
case ssh.HostCert:
|
||||
if a.sshCAHostCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("rekeySSH: host certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
return nil, errs.NotImplemented("rekeySSH; host certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAHostCertSignKey
|
||||
default:
|
||||
return nil, &apiError{
|
||||
err: errors.Errorf("rekeySSH: unexpected ssh certificate type: %d", cert.CertType),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)
|
||||
}
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
|
@ -547,80 +421,47 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
|
|||
data := cert.Marshal()
|
||||
data = data[:len(data)-4]
|
||||
|
||||
// Sign the certificate
|
||||
// Sign the certificate.
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "rekeySSH: error signing certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error signing certificate")
|
||||
}
|
||||
cert.Signature = sig
|
||||
|
||||
// User provisioners validators
|
||||
// Apply validators from provisioner.
|
||||
for _, v := range validators {
|
||||
if err := v.Valid(cert); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
if err := v.Valid(cert, provisioner.SSHOptions{Backdate: backdate}); err != nil {
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH")
|
||||
}
|
||||
}
|
||||
|
||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "rekeySSH: error storing certificate in db"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db")
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// authorizeSSHRevoke authorizes an SSH certificate revoke request, by
|
||||
// validating the contents of an SSHPOP token.
|
||||
func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
errContext := map[string]interface{}{"ott": token}
|
||||
|
||||
p, err := a.authorizeToken(ctx, token)
|
||||
if err != nil {
|
||||
return &apiError{errors.Wrap(err, "authorizeSSHRevoke"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
if err = p.AuthorizeSSHRevoke(ctx, token); err != nil {
|
||||
return &apiError{errors.Wrap(err, "authorizeSSHRevoke"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignSSHAddUser signs a certificate that provisions a new user in a server.
|
||||
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSHAddUser: user certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
|
||||
}
|
||||
if subject.CertType != ssh.UserCert {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSHAddUser: certificate is not a user certificate"),
|
||||
code: http.StatusForbidden,
|
||||
}
|
||||
return nil, errs.Forbidden("signSSHAddUser: certificate is not a user certificate")
|
||||
}
|
||||
if len(subject.ValidPrincipals) != 1 {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSHAddUser: certificate does not have only one principal"),
|
||||
code: http.StatusForbidden,
|
||||
}
|
||||
return nil, errs.Forbidden("signSSHAddUser: certificate does not have only one principal")
|
||||
}
|
||||
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser")
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSHAddUser: error reading random number"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number")
|
||||
}
|
||||
|
||||
signer := a.sshCAUserCertSignKey
|
||||
|
@ -656,10 +497,7 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate)
|
|||
cert.Signature = sig
|
||||
|
||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSHAddUser: error storing certificate in db"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db")
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
|
@ -691,14 +529,12 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st
|
|||
// GetSSHHosts returns a list of valid host principals.
|
||||
func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
if a.sshGetHostsFunc != nil {
|
||||
return a.sshGetHostsFunc(cert)
|
||||
hosts, err := a.sshGetHostsFunc(cert)
|
||||
return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
|
||||
}
|
||||
hostnames, err := a.db.GetSSHHostPrincipals()
|
||||
if err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "getSSHHosts"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
|
||||
}
|
||||
|
||||
hosts := make([]sshutil.Host, len(hostnames))
|
||||
|
|
|
@ -5,8 +5,10 @@ import (
|
|||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -15,6 +17,8 @@ import (
|
|||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/sshutil"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
@ -58,7 +62,7 @@ func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error {
|
|||
|
||||
type sshTestCertValidator string
|
||||
|
||||
func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error {
|
||||
func (v sshTestCertValidator) Valid(crt *ssh.Certificate, opts provisioner.SSHOptions) error {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
|
@ -76,7 +80,7 @@ func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
|
|||
|
||||
type sshTestOptionsModifier string
|
||||
|
||||
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier {
|
||||
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertModifier {
|
||||
return sshTestCertModifier(string(m))
|
||||
}
|
||||
|
||||
|
@ -488,18 +492,18 @@ func TestAuthority_CheckSSHHost(t *testing.T) {
|
|||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"true", fields{true, nil}, args{context.TODO(), "foo.internal.com", ""}, true, false},
|
||||
{"false", fields{false, nil}, args{context.TODO(), "foo.internal.com", ""}, false, false},
|
||||
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true},
|
||||
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true},
|
||||
{"internal", fields{false, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true},
|
||||
{"internal", fields{true, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true},
|
||||
{"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false},
|
||||
{"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false},
|
||||
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
|
||||
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
|
||||
{"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
|
||||
{"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := testAuthority(t)
|
||||
a.db = &MockAuthDB{
|
||||
isSSHHost: func(_ string) (bool, error) {
|
||||
a.db = &db.MockAuthDB{
|
||||
MIsSSHHost: func(_ string) (bool, error) {
|
||||
return tt.fields.exists, tt.fields.err
|
||||
},
|
||||
}
|
||||
|
@ -640,6 +644,9 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
|||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
} else if err != nil {
|
||||
_, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
|
||||
|
@ -647,3 +654,266 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_GetSSHHosts(t *testing.T) {
|
||||
a := testAuthority(t)
|
||||
|
||||
type test struct {
|
||||
getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error)
|
||||
auth *Authority
|
||||
cert *x509.Certificate
|
||||
cmp func(got []sshutil.Host)
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(t *testing.T) *test{
|
||||
"fail/getHostsFunc-fail": func(t *testing.T) *test {
|
||||
return &test{
|
||||
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
cert: &x509.Certificate{},
|
||||
err: errors.New("getSSHHosts: force"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
"ok/getHostsFunc-defined": func(t *testing.T) *test {
|
||||
hosts := []sshutil.Host{
|
||||
{HostID: "1", Hostname: "foo"},
|
||||
{HostID: "2", Hostname: "bar"},
|
||||
}
|
||||
|
||||
return &test{
|
||||
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
return hosts, nil
|
||||
},
|
||||
cert: &x509.Certificate{},
|
||||
cmp: func(got []sshutil.Host) {
|
||||
assert.Equals(t, got, hosts)
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/db-get-fail": func(t *testing.T) *test {
|
||||
return &test{
|
||||
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||
MGetSSHHostPrincipals: func() ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
})),
|
||||
cert: &x509.Certificate{},
|
||||
err: errors.New("getSSHHosts: force"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *test {
|
||||
return &test{
|
||||
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||
MGetSSHHostPrincipals: func() ([]string, error) {
|
||||
return []string{"foo", "bar"}, nil
|
||||
},
|
||||
})),
|
||||
cert: &x509.Certificate{},
|
||||
cmp: func(got []sshutil.Host) {
|
||||
assert.Equals(t, got, []sshutil.Host{
|
||||
{Hostname: "foo"},
|
||||
{Hostname: "bar"},
|
||||
})
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, genTestCase := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := genTestCase(t)
|
||||
|
||||
auth := tc.auth
|
||||
if auth == nil {
|
||||
auth = a
|
||||
}
|
||||
auth.sshGetHostsFunc = tc.getHostsFunc
|
||||
|
||||
hosts, err := auth.GetSSHHosts(tc.cert)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
tc.cmp(hosts)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_RekeySSH(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.FatalError(t, err)
|
||||
pub, err := ssh.NewPublicKey(key.Public())
|
||||
assert.FatalError(t, err)
|
||||
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.FatalError(t, err)
|
||||
signer, err := ssh.NewSignerFromKey(signKey)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userOptions := sshTestModifier{
|
||||
CertType: ssh.UserCert,
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
a := testAuthority(t)
|
||||
|
||||
type test struct {
|
||||
auth *Authority
|
||||
userSigner ssh.Signer
|
||||
hostSigner ssh.Signer
|
||||
cert *ssh.Certificate
|
||||
key ssh.PublicKey
|
||||
signOpts []provisioner.SignOption
|
||||
cmpResult func(old, n *ssh.Certificate)
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(t *testing.T) *test{
|
||||
"fail/opts-type": func(t *testing.T) *test {
|
||||
return &test{
|
||||
userSigner: signer,
|
||||
hostSigner: signer,
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{userOptions},
|
||||
err: errors.New("rekeySSH; invalid extra option type"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
"fail/old-cert-validAfter": func(t *testing.T) *test {
|
||||
return &test{
|
||||
userSigner: signer,
|
||||
hostSigner: signer,
|
||||
cert: &ssh.Certificate{},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("rekeySSH; cannot rekey certificate without validity period"),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
},
|
||||
"fail/old-cert-validBefore": func(t *testing.T) *test {
|
||||
return &test{
|
||||
userSigner: signer,
|
||||
hostSigner: signer,
|
||||
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("rekeySSH; cannot rekey certificate without validity period"),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
},
|
||||
"fail/old-cert-no-user-key": func(t *testing.T) *test {
|
||||
return &test{
|
||||
userSigner: nil,
|
||||
hostSigner: signer,
|
||||
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("rekeySSH; user certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
},
|
||||
"fail/old-cert-no-host-key": func(t *testing.T) *test {
|
||||
return &test{
|
||||
userSigner: signer,
|
||||
hostSigner: nil,
|
||||
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.HostCert},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("rekeySSH; host certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
},
|
||||
"fail/unexpected-old-cert-type": func(t *testing.T) *test {
|
||||
return &test{
|
||||
userSigner: signer,
|
||||
hostSigner: signer,
|
||||
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("rekeySSH; unexpected ssh certificate type: 0"),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
},
|
||||
"fail/db-store": func(t *testing.T) *test {
|
||||
return &test{
|
||||
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||
MStoreSSHCertificate: func(cert *ssh.Certificate) error {
|
||||
return errors.New("force")
|
||||
},
|
||||
})),
|
||||
userSigner: signer,
|
||||
hostSigner: nil,
|
||||
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("rekeySSH; error storing certificate in db: force"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *test {
|
||||
va1 := now.Add(-24 * time.Hour)
|
||||
vb1 := now.Add(-23 * time.Hour)
|
||||
return &test{
|
||||
userSigner: signer,
|
||||
hostSigner: nil,
|
||||
cert: &ssh.Certificate{
|
||||
ValidAfter: uint64(va1.Unix()),
|
||||
ValidBefore: uint64(vb1.Unix()),
|
||||
CertType: ssh.UserCert,
|
||||
ValidPrincipals: []string{"foo", "bar"},
|
||||
KeyId: "foo",
|
||||
},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
cmpResult: func(old, n *ssh.Certificate) {
|
||||
assert.Equals(t, n.CertType, old.CertType)
|
||||
assert.Equals(t, n.ValidPrincipals, old.ValidPrincipals)
|
||||
assert.Equals(t, n.KeyId, old.KeyId)
|
||||
|
||||
assert.True(t, n.ValidAfter > uint64(now.Add(-5*time.Minute).Unix()))
|
||||
assert.True(t, n.ValidAfter < uint64(now.Add(5*time.Minute).Unix()))
|
||||
|
||||
l8r := now.Add(1 * time.Hour)
|
||||
assert.True(t, n.ValidBefore > uint64(l8r.Add(-5*time.Minute).Unix()))
|
||||
assert.True(t, n.ValidBefore < uint64(l8r.Add(5*time.Minute).Unix()))
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, genTestCase := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := genTestCase(t)
|
||||
|
||||
auth := tc.auth
|
||||
if auth == nil {
|
||||
auth = a
|
||||
}
|
||||
a.sshCAUserCertSignKey = tc.userSigner
|
||||
a.sshCAHostCertSignKey = tc.hostSigner
|
||||
|
||||
cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
tc.cmpResult(tc.cert, cert)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
1
authority/testdata/certs/ssh_host_ca_key.pub
vendored
Normal file
1
authority/testdata/certs/ssh_host_ca_key.pub
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY=
|
1
authority/testdata/certs/ssh_user_ca_key.pub
vendored
Normal file
1
authority/testdata/certs/ssh_user_ca_key.pub
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y=
|
5
authority/testdata/secrets/ssh_host_ca_key
vendored
Normal file
5
authority/testdata/secrets/ssh_host_ca_key
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49
|
||||
AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+
|
||||
MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg==
|
||||
-----END EC PRIVATE KEY-----
|
5
authority/testdata/secrets/ssh_user_ca_key
vendored
Normal file
5
authority/testdata/secrets/ssh_user_ca_key
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49
|
||||
AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h
|
||||
YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg==
|
||||
-----END EC PRIVATE KEY-----
|
148
authority/tls.go
148
authority/tls.go
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
|
@ -60,7 +61,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
|||
// Sign creates a signed certificate from a certificate signing request.
|
||||
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
var (
|
||||
errContext = apiCtx{"csr": csr, "signOptions": signOpts}
|
||||
opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
|
||||
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
|
||||
certValidators = []provisioner.CertificateValidator{}
|
||||
issIdentity = a.intermediateIdentity
|
||||
|
@ -75,54 +76,52 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
|
|||
certValidators = append(certValidators, k)
|
||||
case provisioner.CertificateRequestValidator:
|
||||
if err := k.Valid(csr); err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
|
||||
}
|
||||
case provisioner.ProfileModifier:
|
||||
mods = append(mods, k.Option(signOpts))
|
||||
default:
|
||||
return nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k),
|
||||
http.StatusInternalServerError, errContext}
|
||||
return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...)
|
||||
}
|
||||
}
|
||||
|
||||
if err := csr.CheckSignature(); err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "sign: invalid certificate request"),
|
||||
http.StatusBadRequest, errContext}
|
||||
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...)
|
||||
}
|
||||
|
||||
leaf, err := x509util.NewLeafProfileWithCSR(csr, issIdentity.Crt, issIdentity.Key, mods...)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign", opts...)
|
||||
}
|
||||
|
||||
for _, v := range certValidators {
|
||||
if err := v.Valid(leaf.Subject()); err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
|
||||
if err := v.Valid(leaf.Subject(), signOpts); err != nil {
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
|
||||
}
|
||||
}
|
||||
|
||||
crtBytes, err := leaf.CreateCertificate()
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"),
|
||||
http.StatusInternalServerError, errContext}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.Sign; error creating new leaf certificate", opts...)
|
||||
}
|
||||
|
||||
serverCert, err := x509.ParseCertificate(crtBytes)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "sign: error parsing new leaf certificate"),
|
||||
http.StatusInternalServerError, errContext}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.Sign; error parsing new leaf certificate", opts...)
|
||||
}
|
||||
|
||||
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"),
|
||||
http.StatusInternalServerError, errContext}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.Sign; error parsing intermediate certificate", opts...)
|
||||
}
|
||||
|
||||
if err = a.db.StoreCertificate(serverCert); err != nil {
|
||||
if err != db.ErrNotImplemented {
|
||||
return nil, &apiError{errors.Wrap(err, "sign: error storing certificate in db"),
|
||||
http.StatusInternalServerError, errContext}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.Sign; error storing certificate in db", opts...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -132,9 +131,11 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
|
|||
// Renew creates a new Certificate identical to the old certificate, except
|
||||
// with a validity window that begins 'now'.
|
||||
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
|
||||
|
||||
// Check step provisioner extensions
|
||||
if err := a.authorizeRenew(oldCert); err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew", opts...)
|
||||
}
|
||||
|
||||
// Issuer
|
||||
|
@ -161,16 +162,16 @@ func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error
|
|||
MaxPathLenZero: oldCert.MaxPathLenZero,
|
||||
OCSPServer: oldCert.OCSPServer,
|
||||
IssuingCertificateURL: oldCert.IssuingCertificateURL,
|
||||
PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical,
|
||||
PermittedEmailAddresses: oldCert.PermittedEmailAddresses,
|
||||
DNSNames: oldCert.DNSNames,
|
||||
EmailAddresses: oldCert.EmailAddresses,
|
||||
IPAddresses: oldCert.IPAddresses,
|
||||
URIs: oldCert.URIs,
|
||||
PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical,
|
||||
PermittedDNSDomains: oldCert.PermittedDNSDomains,
|
||||
ExcludedDNSDomains: oldCert.ExcludedDNSDomains,
|
||||
PermittedIPRanges: oldCert.PermittedIPRanges,
|
||||
ExcludedIPRanges: oldCert.ExcludedIPRanges,
|
||||
PermittedEmailAddresses: oldCert.PermittedEmailAddresses,
|
||||
ExcludedEmailAddresses: oldCert.ExcludedEmailAddresses,
|
||||
PermittedURIDomains: oldCert.PermittedURIDomains,
|
||||
ExcludedURIDomains: oldCert.ExcludedURIDomains,
|
||||
|
@ -190,29 +191,28 @@ func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error
|
|||
leaf, err := x509util.NewLeafProfileWithTemplate(newCert,
|
||||
issIdentity.Crt, issIdentity.Key)
|
||||
if err != nil {
|
||||
return nil, &apiError{err, http.StatusInternalServerError, apiCtx{}}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew", opts...)
|
||||
}
|
||||
crtBytes, err := leaf.CreateCertificate()
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"),
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.Renew; error renewing certificate from existing server certificate", opts...)
|
||||
}
|
||||
|
||||
serverCert, err := x509.ParseCertificate(crtBytes)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "error parsing new server certificate"),
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.Renew; error parsing new server certificate", opts...)
|
||||
}
|
||||
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"),
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.Renew; error parsing intermediate certificate", opts...)
|
||||
}
|
||||
|
||||
if err = a.db.StoreCertificate(serverCert); err != nil {
|
||||
if err != db.ErrNotImplemented {
|
||||
return nil, &apiError{errors.Wrap(err, "error storing certificate in db"),
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew; error storing certificate in db", opts...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -236,26 +236,26 @@ type RevokeOptions struct {
|
|||
// being renewed.
|
||||
//
|
||||
// TODO: Add OCSP and CRL support.
|
||||
func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error {
|
||||
errContext := apiCtx{
|
||||
"serialNumber": opts.Serial,
|
||||
"reasonCode": opts.ReasonCode,
|
||||
"reason": opts.Reason,
|
||||
"passiveOnly": opts.PassiveOnly,
|
||||
"mTLS": opts.MTLS,
|
||||
"context": string(provisioner.MethodFromContext(ctx)),
|
||||
func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error {
|
||||
opts := []interface{}{
|
||||
errs.WithKeyVal("serialNumber", revokeOpts.Serial),
|
||||
errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode),
|
||||
errs.WithKeyVal("reason", revokeOpts.Reason),
|
||||
errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly),
|
||||
errs.WithKeyVal("MTLS", revokeOpts.MTLS),
|
||||
errs.WithKeyVal("context", string(provisioner.MethodFromContext(ctx))),
|
||||
}
|
||||
if opts.MTLS {
|
||||
errContext["certificate"] = base64.StdEncoding.EncodeToString(opts.Crt.Raw)
|
||||
if revokeOpts.MTLS {
|
||||
opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw)))
|
||||
} else {
|
||||
errContext["ott"] = opts.OTT
|
||||
opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT))
|
||||
}
|
||||
|
||||
rci := &db.RevokedCertificateInfo{
|
||||
Serial: opts.Serial,
|
||||
ReasonCode: opts.ReasonCode,
|
||||
Reason: opts.Reason,
|
||||
MTLS: opts.MTLS,
|
||||
Serial: revokeOpts.Serial,
|
||||
ReasonCode: revokeOpts.ReasonCode,
|
||||
Reason: revokeOpts.Reason,
|
||||
MTLS: revokeOpts.MTLS,
|
||||
RevokedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
|
@ -264,48 +264,43 @@ func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error {
|
|||
err error
|
||||
)
|
||||
// If not mTLS then get the TokenID of the token.
|
||||
if !opts.MTLS {
|
||||
// Validate payload
|
||||
token, err := jose.ParseSigned(opts.OTT)
|
||||
if !revokeOpts.MTLS {
|
||||
token, err := jose.ParseSigned(revokeOpts.OTT)
|
||||
if err != nil {
|
||||
return &apiError{errors.Wrapf(err, "revoke: error parsing token"),
|
||||
http.StatusUnauthorized, errContext}
|
||||
return errs.Wrap(http.StatusUnauthorized, err,
|
||||
"authority.Revoke; error parsing token", opts...)
|
||||
}
|
||||
|
||||
// Get claims w/out verification. We should have already verified this token
|
||||
// earlier with a call to authorizeSSHRevoke.
|
||||
// Get claims w/out verification.
|
||||
var claims Claims
|
||||
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return &apiError{errors.Wrap(err, "revoke"), http.StatusUnauthorized, errContext}
|
||||
return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke", opts...)
|
||||
}
|
||||
|
||||
// This method will also validate the audiences for JWK provisioners.
|
||||
var ok bool
|
||||
p, ok = a.provisioners.LoadByToken(token, &claims.Claims)
|
||||
if !ok {
|
||||
return &apiError{
|
||||
errors.Errorf("revoke: provisioner not found"),
|
||||
http.StatusInternalServerError, errContext}
|
||||
return errs.InternalServer("authority.Revoke; provisioner not found", opts...)
|
||||
}
|
||||
rci.TokenID, err = p.GetTokenID(opts.OTT)
|
||||
rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
|
||||
if err != nil {
|
||||
return &apiError{errors.Wrap(err, "revoke: could not get ID for token"),
|
||||
http.StatusInternalServerError, errContext}
|
||||
return errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.Revoke; could not get ID for token")
|
||||
}
|
||||
errContext["tokenID"] = rci.TokenID
|
||||
opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID))
|
||||
} else {
|
||||
// Load the Certificate provisioner if one exists.
|
||||
p, err = a.LoadProvisionerByCertificate(opts.Crt)
|
||||
p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt)
|
||||
if err != nil {
|
||||
return &apiError{
|
||||
errors.Wrap(err, "revoke: unable to load certificate provisioner"),
|
||||
http.StatusUnauthorized, errContext}
|
||||
return errs.Wrap(http.StatusUnauthorized, err,
|
||||
"authority.Revoke: unable to load certificate provisioner", opts...)
|
||||
}
|
||||
}
|
||||
rci.ProvisionerID = p.GetID()
|
||||
errContext["provisionerID"] = rci.ProvisionerID
|
||||
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
|
||||
|
||||
if provisioner.MethodFromContext(ctx) == provisioner.RevokeSSHMethod {
|
||||
if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {
|
||||
err = a.db.RevokeSSH(rci)
|
||||
} else { // default to revoke x509
|
||||
err = a.db.Revoke(rci)
|
||||
|
@ -314,13 +309,12 @@ func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error {
|
|||
case nil:
|
||||
return nil
|
||||
case db.ErrNotImplemented:
|
||||
return &apiError{errors.New("revoke: no persistence layer configured"),
|
||||
http.StatusNotImplemented, errContext}
|
||||
return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
|
||||
case db.ErrAlreadyExists:
|
||||
return &apiError{errors.Errorf("revoke: certificate with serial number %s has already been revoked", rci.Serial),
|
||||
http.StatusBadRequest, errContext}
|
||||
return errs.BadRequest("authority.Revoke; certificate with serial "+
|
||||
"number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...)
|
||||
default:
|
||||
return &apiError{err, http.StatusInternalServerError, errContext}
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -330,17 +324,17 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
|
|||
a.intermediateIdentity.Crt, a.intermediateIdentity.Key,
|
||||
x509util.WithHosts(strings.Join(a.config.DNSNames, ",")))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
|
||||
}
|
||||
|
||||
crtBytes, err := profile.CreateCertificate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
|
||||
}
|
||||
|
||||
keyPEM, err := pemutil.Serialize(profile.SubjectPrivateKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
|
||||
}
|
||||
|
||||
crtPEM := pem.EncodeToMemory(&pem.Block{
|
||||
|
@ -352,19 +346,21 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
|
|||
// to a tls.Certificate.
|
||||
intermediatePEM, err := pemutil.Serialize(a.intermediateIdentity.Crt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
|
||||
}
|
||||
tlsCrt, err := tls.X509KeyPair(append(crtPEM,
|
||||
pem.EncodeToMemory(intermediatePEM)...),
|
||||
pem.EncodeToMemory(keyPEM))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error creating tls certificate")
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.GetTLSCertificate; error creating tls certificate")
|
||||
}
|
||||
|
||||
// Get the 'leaf' certificate and set the attribute accordingly.
|
||||
leaf, err := x509.ParseCertificate(tlsCrt.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing tls certificate")
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.GetTLSCertificate; error parsing tls certificate")
|
||||
}
|
||||
tlsCrt.Leaf = leaf
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -19,6 +18,7 @@ import (
|
|||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
|
@ -77,7 +77,7 @@ func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateReques
|
|||
return csr
|
||||
}
|
||||
|
||||
func TestSign(t *testing.T) {
|
||||
func TestAuthority_Sign(t *testing.T) {
|
||||
pub, priv, err := keys.GenerateDefaultKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
@ -102,7 +102,7 @@ func TestSign(t *testing.T) {
|
|||
p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK)
|
||||
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||
assert.FatalError(t, err)
|
||||
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
|
||||
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
extraOpts, err := a.Authorize(ctx, token)
|
||||
|
@ -113,7 +113,8 @@ func TestSign(t *testing.T) {
|
|||
csr *x509.CertificateRequest
|
||||
signOpts provisioner.Options
|
||||
extraOpts []provisioner.SignOption
|
||||
err *apiError
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(*testing.T) *signTest{
|
||||
"fail invalid signature": func(t *testing.T) *signTest {
|
||||
|
@ -124,10 +125,8 @@ func TestSign(t *testing.T) {
|
|||
csr: csr,
|
||||
extraOpts: extraOpts,
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: invalid certificate request"),
|
||||
http.StatusBadRequest,
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
err: errors.New("authority.Sign; invalid certificate request"),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
},
|
||||
"fail invalid extra option": func(t *testing.T) *signTest {
|
||||
|
@ -138,10 +137,8 @@ func TestSign(t *testing.T) {
|
|||
csr: csr,
|
||||
extraOpts: append(extraOpts, "42"),
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: invalid extra option type string"),
|
||||
http.StatusInternalServerError,
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
err: errors.New("authority.Sign; invalid extra option type string"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
"fail merge default ASN1DN": func(t *testing.T) *signTest {
|
||||
|
@ -153,10 +150,8 @@ func TestSign(t *testing.T) {
|
|||
csr: csr,
|
||||
extraOpts: extraOpts,
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
|
||||
http.StatusInternalServerError,
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
err: errors.New("authority.Sign: default ASN1DN template cannot be nil"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
"fail create cert": func(t *testing.T) *signTest {
|
||||
|
@ -168,10 +163,8 @@ func TestSign(t *testing.T) {
|
|||
csr: csr,
|
||||
extraOpts: extraOpts,
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: error creating new leaf certificate"),
|
||||
http.StatusInternalServerError,
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
err: errors.New("authority.Sign; error creating new leaf certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
"fail provisioner duration claim": func(t *testing.T) *signTest {
|
||||
|
@ -185,10 +178,8 @@ func TestSign(t *testing.T) {
|
|||
csr: csr,
|
||||
extraOpts: extraOpts,
|
||||
signOpts: _signOpts,
|
||||
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
|
||||
http.StatusUnauthorized,
|
||||
apiCtx{"csr": csr, "signOptions": _signOpts},
|
||||
},
|
||||
err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
|
||||
|
@ -200,10 +191,8 @@ func TestSign(t *testing.T) {
|
|||
csr: csr,
|
||||
extraOpts: extraOpts,
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
||||
http.StatusUnauthorized,
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
err: errors.New("authority.Sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
"fail rsa key too short": func(t *testing.T) *signTest {
|
||||
|
@ -228,20 +217,16 @@ ZYtQ9Ot36qc=
|
|||
csr: csr,
|
||||
extraOpts: extraOpts,
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
|
||||
http.StatusUnauthorized,
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
err: errors.New("authority.Sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
"fail store cert in db": func(t *testing.T) *signTest {
|
||||
csr := getCSR(t, priv)
|
||||
_a := testAuthority(t)
|
||||
_a.db = &MockAuthDB{
|
||||
storeCertificate: func(crt *x509.Certificate) error {
|
||||
return &apiError{errors.New("force"),
|
||||
http.StatusInternalServerError,
|
||||
apiCtx{"csr": csr, "signOptions": signOpts}}
|
||||
_a.db = &db.MockAuthDB{
|
||||
MStoreCertificate: func(crt *x509.Certificate) error {
|
||||
return errors.New("force")
|
||||
},
|
||||
}
|
||||
return &signTest{
|
||||
|
@ -249,17 +234,15 @@ ZYtQ9Ot36qc=
|
|||
csr: csr,
|
||||
extraOpts: extraOpts,
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: error storing certificate in db: force"),
|
||||
http.StatusInternalServerError,
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
err: errors.New("authority.Sign; error storing certificate in db: force"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *signTest {
|
||||
csr := getCSR(t, priv)
|
||||
_a := testAuthority(t)
|
||||
_a.db = &MockAuthDB{
|
||||
storeCertificate: func(crt *x509.Certificate) error {
|
||||
_a.db = &db.MockAuthDB{
|
||||
MStoreCertificate: func(crt *x509.Certificate) error {
|
||||
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
|
||||
return nil
|
||||
},
|
||||
|
@ -279,15 +262,17 @@ ZYtQ9Ot36qc=
|
|||
|
||||
certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch v := err.(type) {
|
||||
case *apiError:
|
||||
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.code, tc.err.code)
|
||||
assert.Equals(t, v.context, tc.err.context)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
||||
assert.Nil(t, certChain)
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
|
||||
ctxErr, ok := err.(*errs.Error)
|
||||
assert.Fatal(t, ok, "error is not of type *errs.Error")
|
||||
assert.Equals(t, ctxErr.Details["csr"], tc.csr)
|
||||
assert.Equals(t, ctxErr.Details["signOptions"], tc.signOpts)
|
||||
}
|
||||
} else {
|
||||
leaf := certChain[0]
|
||||
|
@ -346,7 +331,7 @@ ZYtQ9Ot36qc=
|
|||
}
|
||||
}
|
||||
|
||||
func TestRenew(t *testing.T) {
|
||||
func TestAuthority_Renew(t *testing.T) {
|
||||
pub, _, err := keys.GenerateDefaultKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
@ -375,9 +360,9 @@ func TestRenew(t *testing.T) {
|
|||
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
|
||||
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID))
|
||||
assert.FatalError(t, err)
|
||||
crtBytes, err := leaf.CreateCertificate()
|
||||
certBytes, err := leaf.CreateCertificate()
|
||||
assert.FatalError(t, err)
|
||||
crt, err := x509.ParseCertificate(crtBytes)
|
||||
cert, err := x509.ParseCertificate(certBytes)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
leafNoRenew, err := x509util.NewLeafProfile("norenew", a.intermediateIdentity.Crt,
|
||||
|
@ -388,15 +373,16 @@ func TestRenew(t *testing.T) {
|
|||
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID),
|
||||
)
|
||||
assert.FatalError(t, err)
|
||||
crtBytesNoRenew, err := leafNoRenew.CreateCertificate()
|
||||
certBytesNoRenew, err := leafNoRenew.CreateCertificate()
|
||||
assert.FatalError(t, err)
|
||||
crtNoRenew, err := x509.ParseCertificate(crtBytesNoRenew)
|
||||
certNoRenew, err := x509.ParseCertificate(certBytesNoRenew)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type renewTest struct {
|
||||
auth *Authority
|
||||
crt *x509.Certificate
|
||||
err *apiError
|
||||
cert *x509.Certificate
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func() (*renewTest, error){
|
||||
"fail-create-cert": func() (*renewTest, error) {
|
||||
|
@ -404,25 +390,22 @@ func TestRenew(t *testing.T) {
|
|||
_a.intermediateIdentity.Key = nil
|
||||
return &renewTest{
|
||||
auth: _a,
|
||||
crt: crt,
|
||||
err: &apiError{errors.New("error renewing certificate from existing server certificate"),
|
||||
http.StatusInternalServerError, apiCtx{}},
|
||||
cert: cert,
|
||||
err: errors.New("authority.Renew; error renewing certificate from existing server certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}, nil
|
||||
},
|
||||
"fail-unauthorized": func() (*renewTest, error) {
|
||||
ctx := map[string]interface{}{
|
||||
"serialNumber": crtNoRenew.SerialNumber.String(),
|
||||
}
|
||||
return &renewTest{
|
||||
crt: crtNoRenew,
|
||||
err: &apiError{errors.New("renew: renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
|
||||
http.StatusUnauthorized, ctx},
|
||||
cert: certNoRenew,
|
||||
err: errors.New("authority.Renew: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
"success": func() (*renewTest, error) {
|
||||
return &renewTest{
|
||||
auth: a,
|
||||
crt: crt,
|
||||
cert: cert,
|
||||
}, nil
|
||||
},
|
||||
"success-new-intermediate": func() (*renewTest, error) {
|
||||
|
@ -430,23 +413,23 @@ func TestRenew(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
newRootBytes, err := newRootProfile.CreateCertificate()
|
||||
assert.FatalError(t, err)
|
||||
newRootCrt, err := x509.ParseCertificate(newRootBytes)
|
||||
newRootCert, err := x509.ParseCertificate(newRootBytes)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
newIntermediateProfile, err := x509util.NewIntermediateProfile("new-intermediate",
|
||||
newRootCrt, newRootProfile.SubjectPrivateKey())
|
||||
newRootCert, newRootProfile.SubjectPrivateKey())
|
||||
assert.FatalError(t, err)
|
||||
newIntermediateBytes, err := newIntermediateProfile.CreateCertificate()
|
||||
assert.FatalError(t, err)
|
||||
newIntermediateCrt, err := x509.ParseCertificate(newIntermediateBytes)
|
||||
newIntermediateCert, err := x509.ParseCertificate(newIntermediateBytes)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_a := testAuthority(t)
|
||||
_a.intermediateIdentity.Key = newIntermediateProfile.SubjectPrivateKey()
|
||||
_a.intermediateIdentity.Crt = newIntermediateCrt
|
||||
_a.intermediateIdentity.Crt = newIntermediateCert
|
||||
return &renewTest{
|
||||
auth: _a,
|
||||
crt: crt,
|
||||
cert: cert,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
@ -458,32 +441,33 @@ func TestRenew(t *testing.T) {
|
|||
|
||||
var certChain []*x509.Certificate
|
||||
if tc.auth != nil {
|
||||
certChain, err = tc.auth.Renew(tc.crt)
|
||||
certChain, err = tc.auth.Renew(tc.cert)
|
||||
} else {
|
||||
certChain, err = a.Renew(tc.crt)
|
||||
certChain, err = a.Renew(tc.cert)
|
||||
}
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch v := err.(type) {
|
||||
case *apiError:
|
||||
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.code, tc.err.code)
|
||||
assert.Equals(t, v.context, tc.err.context)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
||||
assert.Nil(t, certChain)
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
|
||||
ctxErr, ok := err.(*errs.Error)
|
||||
assert.Fatal(t, ok, "error is not of type *errs.Error")
|
||||
assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String())
|
||||
}
|
||||
} else {
|
||||
leaf := certChain[0]
|
||||
intermediate := certChain[1]
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.crt.NotAfter.Sub(crt.NotBefore))
|
||||
assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore))
|
||||
|
||||
assert.True(t, leaf.NotBefore.After(now.Add(-time.Minute)))
|
||||
assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute)))
|
||||
assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute)))
|
||||
|
||||
expiry := now.Add(time.Minute * 7)
|
||||
assert.True(t, leaf.NotAfter.After(expiry.Add(-time.Minute)))
|
||||
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
|
||||
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
|
||||
|
||||
tmplt := a.config.AuthorityConfig.Template
|
||||
|
@ -513,7 +497,7 @@ func TestRenew(t *testing.T) {
|
|||
if a.intermediateIdentity.Crt.SerialNumber == tc.auth.intermediateIdentity.Crt.SerialNumber {
|
||||
assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId)
|
||||
// Compare extensions: they can be in a different order
|
||||
for _, ext1 := range tc.crt.Extensions {
|
||||
for _, ext1 := range tc.cert.Extensions {
|
||||
found := false
|
||||
for _, ext2 := range leaf.Extensions {
|
||||
if reflect.DeepEqual(ext1, ext2) {
|
||||
|
@ -529,7 +513,7 @@ func TestRenew(t *testing.T) {
|
|||
// We did change the intermediate before renewing.
|
||||
assert.Equals(t, leaf.AuthorityKeyId, tc.auth.intermediateIdentity.Crt.SubjectKeyId)
|
||||
// Compare extensions: they can be in a different order
|
||||
for _, ext1 := range tc.crt.Extensions {
|
||||
for _, ext1 := range tc.cert.Extensions {
|
||||
// The authority key id extension should be different b/c the intermediates are different.
|
||||
if ext1.Id.Equal(oidAuthorityKeyIdentifier) {
|
||||
for _, ext2 := range leaf.Extensions {
|
||||
|
@ -560,7 +544,7 @@ func TestRenew(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetTLSOptions(t *testing.T) {
|
||||
func TestAuthority_GetTLSOptions(t *testing.T) {
|
||||
type renewTest struct {
|
||||
auth *Authority
|
||||
opts *tlsutil.TLSOptions
|
||||
|
@ -596,21 +580,12 @@ func TestGetTLSOptions(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRevoke(t *testing.T) {
|
||||
func TestAuthority_Revoke(t *testing.T) {
|
||||
reasonCode := 2
|
||||
reason := "bob was let go"
|
||||
validIssuer := "step-cli"
|
||||
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
|
||||
validAudience := testAudiences.Revoke
|
||||
now := time.Now().UTC()
|
||||
getCtx := func() map[string]interface{} {
|
||||
return apiCtx{
|
||||
"serialNumber": "sn",
|
||||
"reasonCode": reasonCode,
|
||||
"reason": reason,
|
||||
"mTLS": false,
|
||||
"passiveOnly": false,
|
||||
}
|
||||
}
|
||||
|
||||
jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||
assert.FatalError(t, err)
|
||||
|
@ -619,30 +594,30 @@ func TestRevoke(t *testing.T) {
|
|||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
|
||||
assert.FatalError(t, err)
|
||||
|
||||
a := testAuthority(t)
|
||||
|
||||
type test struct {
|
||||
a *Authority
|
||||
auth *Authority
|
||||
opts *RevokeOptions
|
||||
err *apiError
|
||||
err error
|
||||
code int
|
||||
checkErrDetails func(err *errs.Error)
|
||||
}
|
||||
tests := map[string]func() test{
|
||||
"error/token/authorizeRevoke error": func() test {
|
||||
a := testAuthority(t)
|
||||
ctx := getCtx()
|
||||
ctx["ott"] = "foo"
|
||||
"fail/token/authorizeRevoke error": func() test {
|
||||
return test{
|
||||
a: a,
|
||||
auth: a,
|
||||
opts: &RevokeOptions{
|
||||
OTT: "foo",
|
||||
Serial: "sn",
|
||||
ReasonCode: reasonCode,
|
||||
Reason: reason,
|
||||
},
|
||||
err: &apiError{errors.New("revoke: authorizeRevoke: authorizeToken: error parsing token"),
|
||||
http.StatusUnauthorized, ctx},
|
||||
err: errors.New("authority.Revoke; error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
"error/nil-db": func() test {
|
||||
a := testAuthority(t)
|
||||
"fail/nil-db": func() test {
|
||||
cl := jwt.Claims{
|
||||
Subject: "sn",
|
||||
Issuer: validIssuer,
|
||||
|
@ -654,30 +629,30 @@ func TestRevoke(t *testing.T) {
|
|||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
ctx := getCtx()
|
||||
ctx["ott"] = raw
|
||||
ctx["tokenID"] = "44"
|
||||
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
|
||||
return test{
|
||||
a: a,
|
||||
auth: a,
|
||||
opts: &RevokeOptions{
|
||||
Serial: "sn",
|
||||
ReasonCode: reasonCode,
|
||||
Reason: reason,
|
||||
OTT: raw,
|
||||
},
|
||||
err: &apiError{errors.New("revoke: no persistence layer configured"),
|
||||
http.StatusNotImplemented, ctx},
|
||||
err: errors.New("authority.Revoke; no persistence layer configured"),
|
||||
code: http.StatusNotImplemented,
|
||||
checkErrDetails: func(err *errs.Error) {
|
||||
assert.Equals(t, err.Details["token"], raw)
|
||||
assert.Equals(t, err.Details["tokenID"], "44")
|
||||
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
|
||||
},
|
||||
}
|
||||
},
|
||||
"error/db-revoke": func() test {
|
||||
a := testAuthority(t)
|
||||
a.db = &MockAuthDB{
|
||||
useToken: func(id, tok string) (bool, error) {
|
||||
"fail/db-revoke": func() test {
|
||||
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
err: errors.New("force"),
|
||||
}
|
||||
Err: errors.New("force"),
|
||||
}))
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "sn",
|
||||
|
@ -690,30 +665,30 @@ func TestRevoke(t *testing.T) {
|
|||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
ctx := getCtx()
|
||||
ctx["ott"] = raw
|
||||
ctx["tokenID"] = "44"
|
||||
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
|
||||
return test{
|
||||
a: a,
|
||||
auth: _a,
|
||||
opts: &RevokeOptions{
|
||||
Serial: "sn",
|
||||
ReasonCode: reasonCode,
|
||||
Reason: reason,
|
||||
OTT: raw,
|
||||
},
|
||||
err: &apiError{errors.New("force"),
|
||||
http.StatusInternalServerError, ctx},
|
||||
err: errors.New("authority.Revoke: force"),
|
||||
code: http.StatusInternalServerError,
|
||||
checkErrDetails: func(err *errs.Error) {
|
||||
assert.Equals(t, err.Details["token"], raw)
|
||||
assert.Equals(t, err.Details["tokenID"], "44")
|
||||
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
|
||||
},
|
||||
}
|
||||
},
|
||||
"error/already-revoked": func() test {
|
||||
a := testAuthority(t)
|
||||
a.db = &MockAuthDB{
|
||||
useToken: func(id, tok string) (bool, error) {
|
||||
"fail/already-revoked": func() test {
|
||||
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
err: db.ErrAlreadyExists,
|
||||
}
|
||||
Err: db.ErrAlreadyExists,
|
||||
}))
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "sn",
|
||||
|
@ -726,29 +701,29 @@ func TestRevoke(t *testing.T) {
|
|||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
ctx := getCtx()
|
||||
ctx["ott"] = raw
|
||||
ctx["tokenID"] = "44"
|
||||
ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc"
|
||||
return test{
|
||||
a: a,
|
||||
auth: _a,
|
||||
opts: &RevokeOptions{
|
||||
Serial: "sn",
|
||||
ReasonCode: reasonCode,
|
||||
Reason: reason,
|
||||
OTT: raw,
|
||||
},
|
||||
err: &apiError{errors.New("revoke: certificate with serial number sn has already been revoked"),
|
||||
http.StatusBadRequest, ctx},
|
||||
err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"),
|
||||
code: http.StatusBadRequest,
|
||||
checkErrDetails: func(err *errs.Error) {
|
||||
assert.Equals(t, err.Details["token"], raw)
|
||||
assert.Equals(t, err.Details["tokenID"], "44")
|
||||
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/token": func() test {
|
||||
a := testAuthority(t)
|
||||
a.db = &MockAuthDB{
|
||||
useToken: func(id, tok string) (bool, error) {
|
||||
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
}))
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "sn",
|
||||
|
@ -761,7 +736,7 @@ func TestRevoke(t *testing.T) {
|
|||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
a: a,
|
||||
auth: _a,
|
||||
opts: &RevokeOptions{
|
||||
Serial: "sn",
|
||||
ReasonCode: reasonCode,
|
||||
|
@ -770,39 +745,14 @@ func TestRevoke(t *testing.T) {
|
|||
},
|
||||
}
|
||||
},
|
||||
"error/mTLS/authorizeRevoke": func() test {
|
||||
a := testAuthority(t)
|
||||
a.db = &MockAuthDB{}
|
||||
|
||||
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
ctx := getCtx()
|
||||
ctx["certificate"] = base64.StdEncoding.EncodeToString(crt.Raw)
|
||||
ctx["mTLS"] = true
|
||||
|
||||
return test{
|
||||
a: a,
|
||||
opts: &RevokeOptions{
|
||||
Crt: crt,
|
||||
Serial: "sn",
|
||||
ReasonCode: reasonCode,
|
||||
Reason: reason,
|
||||
MTLS: true,
|
||||
},
|
||||
err: &apiError{errors.New("revoke: authorizeRevoke: serial number in certificate different than body"),
|
||||
http.StatusUnauthorized, ctx},
|
||||
}
|
||||
},
|
||||
"ok/mTLS": func() test {
|
||||
a := testAuthority(t)
|
||||
a.db = &MockAuthDB{}
|
||||
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{}))
|
||||
|
||||
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
a: a,
|
||||
auth: _a,
|
||||
opts: &RevokeOptions{
|
||||
Crt: crt,
|
||||
Serial: "102012593071130646873265215610956555026",
|
||||
|
@ -816,15 +766,24 @@ func TestRevoke(t *testing.T) {
|
|||
for name, f := range tests {
|
||||
tc := f()
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.a.Revoke(context.TODO(), tc.opts); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch v := err.(type) {
|
||||
case *apiError:
|
||||
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.code, tc.err.code)
|
||||
assert.Equals(t, v.context, tc.err.context)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
|
||||
if err := tc.auth.Revoke(ctx, tc.opts); err != nil {
|
||||
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
|
||||
ctxErr, ok := err.(*errs.Error)
|
||||
assert.Fatal(t, ok, "error is not of type *errs.Error")
|
||||
assert.Equals(t, ctxErr.Details["serialNumber"], tc.opts.Serial)
|
||||
assert.Equals(t, ctxErr.Details["reasonCode"], tc.opts.ReasonCode)
|
||||
assert.Equals(t, ctxErr.Details["reason"], tc.opts.Reason)
|
||||
assert.Equals(t, ctxErr.Details["MTLS"], tc.opts.MTLS)
|
||||
assert.Equals(t, ctxErr.Details["context"], string(provisioner.RevokeMethod))
|
||||
|
||||
if tc.checkErrDetails != nil {
|
||||
tc.checkErrDetails(ctxErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
|
@ -102,7 +103,7 @@ func TestCASign(t *testing.T) {
|
|||
ca: ca,
|
||||
body: "invalid json",
|
||||
status: http.StatusBadRequest,
|
||||
errMsg: "Bad Request",
|
||||
errMsg: errs.BadRequestDefaultMsg,
|
||||
}
|
||||
},
|
||||
"fail invalid-csr-sig": func(t *testing.T) *signTest {
|
||||
|
@ -140,7 +141,7 @@ ZEp7knvU2psWRw==
|
|||
ca: ca,
|
||||
body: string(body),
|
||||
status: http.StatusBadRequest,
|
||||
errMsg: "Bad Request",
|
||||
errMsg: errs.BadRequestDefaultMsg,
|
||||
}
|
||||
},
|
||||
"fail unauthorized-ott": func(t *testing.T) *signTest {
|
||||
|
@ -155,7 +156,7 @@ ZEp7knvU2psWRw==
|
|||
ca: ca,
|
||||
body: string(body),
|
||||
status: http.StatusUnauthorized,
|
||||
errMsg: "Unauthorized",
|
||||
errMsg: errs.UnauthorizedDefaultMsg,
|
||||
}
|
||||
},
|
||||
"fail commonname-claim": func(t *testing.T) *signTest {
|
||||
|
@ -188,7 +189,7 @@ ZEp7knvU2psWRw==
|
|||
ca: ca,
|
||||
body: string(body),
|
||||
status: http.StatusUnauthorized,
|
||||
errMsg: "Unauthorized",
|
||||
errMsg: errs.UnauthorizedDefaultMsg,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *signTest {
|
||||
|
@ -392,7 +393,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
|
|||
ca: ca,
|
||||
kid: "foo",
|
||||
status: http.StatusNotFound,
|
||||
errMsg: "Not Found",
|
||||
errMsg: errs.NotFoundDefaultMsg,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *ekt {
|
||||
|
@ -455,7 +456,7 @@ func TestCARoot(t *testing.T) {
|
|||
ca: ca,
|
||||
sha: "foo",
|
||||
status: http.StatusNotFound,
|
||||
errMsg: "Not Found",
|
||||
errMsg: errs.NotFoundDefaultMsg,
|
||||
}
|
||||
},
|
||||
"success": func(t *testing.T) *rootTest {
|
||||
|
@ -575,7 +576,7 @@ func TestCARenew(t *testing.T) {
|
|||
ca: ca,
|
||||
tlsConnState: nil,
|
||||
status: http.StatusBadRequest,
|
||||
errMsg: "Bad Request",
|
||||
errMsg: errs.BadRequestDefaultMsg,
|
||||
}
|
||||
},
|
||||
"request-missing-peer-certificate": func(t *testing.T) *renewTest {
|
||||
|
@ -583,7 +584,7 @@ func TestCARenew(t *testing.T) {
|
|||
ca: ca,
|
||||
tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}},
|
||||
status: http.StatusBadRequest,
|
||||
errMsg: "Bad Request",
|
||||
errMsg: errs.BadRequestDefaultMsg,
|
||||
}
|
||||
},
|
||||
"success": func(t *testing.T) *renewTest {
|
||||
|
|
38
ca/client.go
38
ca/client.go
|
@ -486,7 +486,7 @@ func (c *Client) Version() (*api.VersionResponse, error) {
|
|||
retry:
|
||||
resp, err := c.client.Get(u.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; client GET %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
|
@ -497,7 +497,7 @@ retry:
|
|||
}
|
||||
var version api.VersionResponse
|
||||
if err := readJSON(resp.Body, &version); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; error reading %s", u)
|
||||
}
|
||||
return &version, nil
|
||||
}
|
||||
|
@ -510,7 +510,7 @@ func (c *Client) Health() (*api.HealthResponse, error) {
|
|||
retry:
|
||||
resp, err := c.client.Get(u.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; client GET %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
|
@ -521,7 +521,7 @@ retry:
|
|||
}
|
||||
var health api.HealthResponse
|
||||
if err := readJSON(resp.Body, &health); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; error reading %s", u)
|
||||
}
|
||||
return &health, nil
|
||||
}
|
||||
|
@ -537,7 +537,7 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
|||
retry:
|
||||
resp, err := newInsecureClient().Get(u.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; client GET %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
|
@ -548,12 +548,12 @@ retry:
|
|||
}
|
||||
var root api.RootResponse
|
||||
if err := readJSON(resp.Body, &root); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; error reading %s", u)
|
||||
}
|
||||
// verify the sha256
|
||||
sum := sha256.Sum256(root.RootPEM.Raw)
|
||||
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
|
||||
return nil, errors.New("root certificate SHA256 fingerprint do not match")
|
||||
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
|
||||
}
|
||||
return &root, nil
|
||||
}
|
||||
|
@ -564,13 +564,13 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
|
|||
var retried bool
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshaling request")
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "client.Sign; error marshaling request")
|
||||
}
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
|
||||
retry:
|
||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; client POST %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
|
@ -581,7 +581,7 @@ retry:
|
|||
}
|
||||
var sign api.SignResponse
|
||||
if err := readJSON(resp.Body, &sign); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; error reading %s", u)
|
||||
}
|
||||
// Add tls.ConnectionState:
|
||||
// We'll extract the root certificate from the verified chains
|
||||
|
@ -598,7 +598,7 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
|||
retry:
|
||||
resp, err := client.Post(u.String(), "application/json", http.NoBody)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; client POST %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
|
@ -609,7 +609,7 @@ retry:
|
|||
}
|
||||
var sign api.SignResponse
|
||||
if err := readJSON(resp.Body, &sign); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; error reading %s", u)
|
||||
}
|
||||
return &sign, nil
|
||||
}
|
||||
|
@ -961,8 +961,8 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin
|
|||
retry:
|
||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u,
|
||||
errs.WithMessage("Failed to perform POST request to %s", u))
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed",
|
||||
[]interface{}{u, errs.WithMessage("Failed to perform POST request to %s", u)}...)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
|
@ -974,8 +974,8 @@ retry:
|
|||
}
|
||||
var check api.SSHCheckPrincipalResponse
|
||||
if err := readJSON(resp.Body, &check); err != nil {
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u,
|
||||
errs.WithMessage("Failed to parse response from /ssh/check-host endpoint"))
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response",
|
||||
[]interface{}{u, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")})
|
||||
}
|
||||
return &check, nil
|
||||
}
|
||||
|
@ -1008,13 +1008,13 @@ func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse
|
|||
var retried bool
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshaling request")
|
||||
return nil, errors.Wrap(err, "client.SSHBastion; error marshaling request")
|
||||
}
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"})
|
||||
retry:
|
||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||
return nil, errors.Wrapf(err, "client.SSHBastion; client POST %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
|
@ -1025,7 +1025,7 @@ retry:
|
|||
}
|
||||
var bastion api.SSHBastionResponse
|
||||
if err := readJSON(resp.Body, &bastion); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
return nil, errors.Wrapf(err, "client.SSHBastion; error reading %s", u)
|
||||
}
|
||||
return &bastion, nil
|
||||
}
|
||||
|
|
|
@ -16,12 +16,12 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
@ -154,18 +154,17 @@ func equalJSON(t *testing.T, a interface{}, b interface{}) bool {
|
|||
|
||||
func TestClient_Version(t *testing.T) {
|
||||
ok := &api.VersionResponse{Version: "test"}
|
||||
internal := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
|
||||
notFound := errs.NotFound(fmt.Errorf("Not Found"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
expectedErr error
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"500", internal, 500, true},
|
||||
{"404", notFound, 404, true},
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
|
||||
{"404", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -185,7 +184,6 @@ func TestClient_Version(t *testing.T) {
|
|||
|
||||
got, err := c.Version()
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Printf("%+v", err)
|
||||
t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
@ -195,9 +193,7 @@ func TestClient_Version(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Version() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Version() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Version() = %v, want %v", got, tt.response)
|
||||
|
@ -209,16 +205,16 @@ func TestClient_Version(t *testing.T) {
|
|||
|
||||
func TestClient_Health(t *testing.T) {
|
||||
ok := &api.HealthResponse{Status: "ok"}
|
||||
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
expectedErr error
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"not ok", nok, 500, true},
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"not ok", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -248,9 +244,7 @@ func TestClient_Health(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Health() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Health() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Health() = %v, want %v", got, tt.response)
|
||||
|
@ -264,7 +258,6 @@ func TestClient_Root(t *testing.T) {
|
|||
ok := &api.RootResponse{
|
||||
RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
||||
}
|
||||
notFound := errs.NotFound(fmt.Errorf("Not Found"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -272,9 +265,10 @@ func TestClient_Root(t *testing.T) {
|
|||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
expectedErr error
|
||||
}{
|
||||
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false},
|
||||
{"not found", "invalid", notFound, 404, true},
|
||||
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil},
|
||||
{"not found", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -307,9 +301,7 @@ func TestClient_Root(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Root() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Root() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Root() = %v, want %v", got, tt.response)
|
||||
|
@ -334,8 +326,6 @@ func TestClient_Sign(t *testing.T) {
|
|||
NotBefore: api.NewTimeDuration(time.Now()),
|
||||
NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)),
|
||||
}
|
||||
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -343,11 +333,12 @@ func TestClient_Sign(t *testing.T) {
|
|||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
expectedErr error
|
||||
}{
|
||||
{"ok", request, ok, 200, false},
|
||||
{"unauthorized", request, unauthorized, 401, true},
|
||||
{"empty request", &api.SignRequest{}, badRequest, 403, true},
|
||||
{"nil request", nil, badRequest, 403, true},
|
||||
{"ok", request, ok, 200, false, nil},
|
||||
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -364,7 +355,9 @@ func TestClient_Sign(t *testing.T) {
|
|||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body := new(api.SignRequest)
|
||||
if err := api.ReadJSON(req.Body, body); err != nil {
|
||||
api.WriteError(w, badRequest)
|
||||
e, ok := tt.response.(error)
|
||||
assert.Fatal(t, ok, "response expected to be error type")
|
||||
api.WriteError(w, e)
|
||||
return
|
||||
} else if !equalJSON(t, body, tt.request) {
|
||||
if tt.request == nil {
|
||||
|
@ -390,9 +383,7 @@ func TestClient_Sign(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Sign() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Sign() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
|
||||
|
@ -409,19 +400,17 @@ func TestClient_Revoke(t *testing.T) {
|
|||
OTT: "the-ott",
|
||||
ReasonCode: 4,
|
||||
}
|
||||
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request *api.RevokeRequest
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
expectedErr error
|
||||
}{
|
||||
{"ok", request, ok, 200, false},
|
||||
{"unauthorized", request, unauthorized, 401, true},
|
||||
{"nil request", nil, badRequest, 403, true},
|
||||
{"ok", request, ok, 200, false, nil},
|
||||
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -438,7 +427,9 @@ func TestClient_Revoke(t *testing.T) {
|
|||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body := new(api.RevokeRequest)
|
||||
if err := api.ReadJSON(req.Body, body); err != nil {
|
||||
api.WriteError(w, badRequest)
|
||||
e, ok := tt.response.(error)
|
||||
assert.Fatal(t, ok, "response expected to be error type")
|
||||
api.WriteError(w, e)
|
||||
return
|
||||
} else if !equalJSON(t, body, tt.request) {
|
||||
if tt.request == nil {
|
||||
|
@ -464,9 +455,7 @@ func TestClient_Revoke(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Revoke() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Revoke() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
|
||||
|
@ -485,19 +474,18 @@ func TestClient_Renew(t *testing.T) {
|
|||
{Certificate: parseCertificate(rootPEM)},
|
||||
},
|
||||
}
|
||||
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"unauthorized", unauthorized, 401, true},
|
||||
{"empty request", badRequest, 403, true},
|
||||
{"nil request", badRequest, 403, true},
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -527,9 +515,11 @@ func TestClient_Renew(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Renew() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Renew() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
|
||||
|
@ -543,7 +533,7 @@ func TestClient_Provisioners(t *testing.T) {
|
|||
ok := &api.ProvisionersResponse{
|
||||
Provisioners: provisioner.List{},
|
||||
}
|
||||
internalServerError := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
|
||||
internalServerError := errs.InternalServer("Internal Server Error")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -589,9 +579,7 @@ func TestClient_Provisioners(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Provisioners() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Provisioners() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
|
||||
|
@ -605,7 +593,6 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|||
ok := &api.ProvisionerKeyResponse{
|
||||
Key: "an encrypted key",
|
||||
}
|
||||
notFound := errs.NotFound(fmt.Errorf("Not Found"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -613,9 +600,10 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
err error
|
||||
}{
|
||||
{"ok", "kid", ok, 200, false},
|
||||
{"fail", "invalid", notFound, 500, true},
|
||||
{"ok", "kid", ok, 200, false, nil},
|
||||
{"fail", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -648,9 +636,11 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.ProvisionerKey() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
|
||||
|
@ -666,19 +656,17 @@ func TestClient_Roots(t *testing.T) {
|
|||
{Certificate: parseCertificate(rootPEM)},
|
||||
},
|
||||
}
|
||||
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"unauthorized", unauthorized, 401, true},
|
||||
{"empty request", badRequest, 403, true},
|
||||
{"nil request", badRequest, 403, true},
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -708,9 +696,10 @@ func TestClient_Roots(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Roots() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Roots() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
|
||||
|
@ -726,19 +715,16 @@ func TestClient_Federation(t *testing.T) {
|
|||
{Certificate: parseCertificate(rootPEM)},
|
||||
},
|
||||
}
|
||||
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"unauthorized", unauthorized, 401, true},
|
||||
{"empty request", badRequest, 403, true},
|
||||
{"nil request", badRequest, 403, true},
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -768,9 +754,10 @@ func TestClient_Federation(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.Federation() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
|
||||
|
@ -790,16 +777,16 @@ func TestClient_SSHRoots(t *testing.T) {
|
|||
HostKeys: []api.SSHPublicKey{{PublicKey: key}},
|
||||
UserKeys: []api.SSHPublicKey{{PublicKey: key}},
|
||||
}
|
||||
notFound := errs.NotFound(fmt.Errorf("Not Found"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"not found", notFound, 404, true},
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"not found", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -829,9 +816,10 @@ func TestClient_SSHRoots(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.SSHKeys() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.SSHKeys() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
|
||||
|
@ -881,7 +869,7 @@ func Test_parseEndpoint(t *testing.T) {
|
|||
|
||||
func TestClient_RootFingerprint(t *testing.T) {
|
||||
ok := &api.HealthResponse{Status: "ok"}
|
||||
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
|
||||
nok := errs.InternalServer("Internal Server Error")
|
||||
|
||||
httpsServer := httptest.NewTLSServer(nil)
|
||||
defer httpsServer.Close()
|
||||
|
@ -948,7 +936,6 @@ func TestClient_SSHBastion(t *testing.T) {
|
|||
Hostname: "bastion.local",
|
||||
},
|
||||
}
|
||||
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -956,11 +943,11 @@ func TestClient_SSHBastion(t *testing.T) {
|
|||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
err error
|
||||
}{
|
||||
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false},
|
||||
{"bad response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true},
|
||||
{"empty request", &api.SSHBastionRequest{}, badRequest, 403, true},
|
||||
{"nil request", nil, badRequest, 403, true},
|
||||
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
|
||||
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
|
||||
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -990,8 +977,11 @@ func TestClient_SSHBastion(t *testing.T) {
|
|||
if got != nil {
|
||||
t.Errorf("Client.SSHBastion() = %v, want nil", got)
|
||||
}
|
||||
if tt.responseCode != 200 && !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.SSHBastion() error = %v, want %v", err, tt.response)
|
||||
if tt.responseCode != 200 {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||
}
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
|
|
|
@ -276,6 +276,7 @@ func TestIdentity_Renew(t *testing.T) {
|
|||
}
|
||||
|
||||
oldIdentityDir := identityDir
|
||||
identityDir = "testdata/identity"
|
||||
defer func() {
|
||||
identityDir = oldIdentityDir
|
||||
os.RemoveAll(tmpDir)
|
||||
|
|
|
@ -40,7 +40,7 @@ func (c *mutableTLSConfig) Init(base *tls.Config) {
|
|||
// tls.Config GetConfigForClient.
|
||||
func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) {
|
||||
c.RLock()
|
||||
config = c.config
|
||||
config = c.config.Clone()
|
||||
c.RUnlock()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -80,7 +80,9 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
|
|||
func (r *TLSRenewer) Run() {
|
||||
cert := r.getCertificate()
|
||||
next := r.nextRenewDuration(cert.Leaf.NotAfter)
|
||||
r.Lock()
|
||||
r.timer = time.AfterFunc(next, r.renewCertificate)
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
// RunContext starts the certificate renewer for the given certificate.
|
||||
|
|
99
db/db.go
99
db/db.go
|
@ -270,6 +270,105 @@ func (db *DB) Shutdown() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// MockAuthDB mocks the AuthDB interface. //
|
||||
type MockAuthDB struct {
|
||||
Err error
|
||||
Ret1 interface{}
|
||||
MIsRevoked func(string) (bool, error)
|
||||
MIsSSHRevoked func(string) (bool, error)
|
||||
MRevoke func(rci *RevokedCertificateInfo) error
|
||||
MRevokeSSH func(rci *RevokedCertificateInfo) error
|
||||
MStoreCertificate func(crt *x509.Certificate) error
|
||||
MUseToken func(id, tok string) (bool, error)
|
||||
MIsSSHHost func(principal string) (bool, error)
|
||||
MStoreSSHCertificate func(crt *ssh.Certificate) error
|
||||
MGetSSHHostPrincipals func() ([]string, error)
|
||||
MShutdown func() error
|
||||
}
|
||||
|
||||
// IsRevoked mock.
|
||||
func (m *MockAuthDB) IsRevoked(sn string) (bool, error) {
|
||||
if m.MIsRevoked != nil {
|
||||
return m.MIsRevoked(sn)
|
||||
}
|
||||
return m.Ret1.(bool), m.Err
|
||||
}
|
||||
|
||||
// IsSSHRevoked mock.
|
||||
func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) {
|
||||
if m.MIsSSHRevoked != nil {
|
||||
return m.MIsSSHRevoked(sn)
|
||||
}
|
||||
return m.Ret1.(bool), m.Err
|
||||
}
|
||||
|
||||
// UseToken mock.
|
||||
func (m *MockAuthDB) UseToken(id, tok string) (bool, error) {
|
||||
if m.MUseToken != nil {
|
||||
return m.MUseToken(id, tok)
|
||||
}
|
||||
if m.Ret1 == nil {
|
||||
return false, m.Err
|
||||
}
|
||||
return m.Ret1.(bool), m.Err
|
||||
}
|
||||
|
||||
// Revoke mock.
|
||||
func (m *MockAuthDB) Revoke(rci *RevokedCertificateInfo) error {
|
||||
if m.MRevoke != nil {
|
||||
return m.MRevoke(rci)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// RevokeSSH mock.
|
||||
func (m *MockAuthDB) RevokeSSH(rci *RevokedCertificateInfo) error {
|
||||
if m.MRevokeSSH != nil {
|
||||
return m.MRevokeSSH(rci)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// StoreCertificate mock.
|
||||
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
|
||||
if m.MStoreCertificate != nil {
|
||||
return m.MStoreCertificate(crt)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// IsSSHHost mock.
|
||||
func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) {
|
||||
if m.MIsSSHHost != nil {
|
||||
return m.MIsSSHHost(principal)
|
||||
}
|
||||
return m.Ret1.(bool), m.Err
|
||||
}
|
||||
|
||||
// StoreSSHCertificate mock.
|
||||
func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
||||
if m.MStoreSSHCertificate != nil {
|
||||
return m.MStoreSSHCertificate(crt)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// GetSSHHostPrincipals mock.
|
||||
func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) {
|
||||
if m.MGetSSHHostPrincipals != nil {
|
||||
return m.MGetSSHHostPrincipals()
|
||||
}
|
||||
return m.Ret1.([]string), m.Err
|
||||
}
|
||||
|
||||
// Shutdown mock.
|
||||
func (m *MockAuthDB) Shutdown() error {
|
||||
if m.MShutdown != nil {
|
||||
return m.MShutdown()
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// MockNoSQLDB //
|
||||
type MockNoSQLDB struct {
|
||||
Err error
|
||||
|
|
255
errs/error.go
255
errs/error.go
|
@ -21,9 +21,9 @@ type StackTracer interface {
|
|||
// Option modifies the Error type.
|
||||
type Option func(e *Error) error
|
||||
|
||||
// WithMessage returns an Option that modifies the error by overwriting the
|
||||
// withDefaultMessage returns an Option that modifies the error by overwriting the
|
||||
// message only if it is empty.
|
||||
func WithMessage(format string, args ...interface{}) Option {
|
||||
func withDefaultMessage(format string, args ...interface{}) Option {
|
||||
return func(e *Error) error {
|
||||
if len(e.Msg) > 0 {
|
||||
return e
|
||||
|
@ -33,31 +33,33 @@ func WithMessage(format string, args ...interface{}) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithMessage returns an Option that modifies the error by overwriting the
|
||||
// message only if it is empty.
|
||||
func WithMessage(format string, args ...interface{}) Option {
|
||||
return func(e *Error) error {
|
||||
e.Msg = fmt.Sprintf(format, args...)
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
||||
// WithKeyVal returns an Option that adds the given key-value pair to the
|
||||
// Error details. This is helpful for debugging errors.
|
||||
func WithKeyVal(key string, val interface{}) Option {
|
||||
return func(e *Error) error {
|
||||
if e.Details == nil {
|
||||
e.Details = make(map[string]interface{})
|
||||
}
|
||||
e.Details[key] = val
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
||||
// Error represents the CA API errors.
|
||||
type Error struct {
|
||||
Status int
|
||||
Err error
|
||||
Msg string
|
||||
}
|
||||
|
||||
// New returns a new Error. If the given error implements the StatusCoder
|
||||
// interface we will ignore the given status.
|
||||
func New(status int, err error, opts ...Option) error {
|
||||
var e *Error
|
||||
if sc, ok := err.(StatusCoder); ok {
|
||||
e = &Error{Status: sc.StatusCode(), Err: err}
|
||||
} else {
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := cause.(StatusCoder); ok {
|
||||
e = &Error{Status: sc.StatusCode(), Err: err}
|
||||
} else {
|
||||
e = &Error{Status: status, Err: err}
|
||||
}
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(e)
|
||||
}
|
||||
return e
|
||||
Details map[string]interface{}
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error in JSON format.
|
||||
|
@ -92,10 +94,11 @@ func (e *Error) Message() string {
|
|||
|
||||
// Wrap returns an error annotating err with a stack trace at the point Wrap is
|
||||
// called, and the supplied message. If err is nil, Wrap returns nil.
|
||||
func Wrap(status int, e error, m string, opts ...Option) error {
|
||||
func Wrap(status int, e error, m string, args ...interface{}) error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
_, opts := splitOptionArgs(args)
|
||||
if err, ok := e.(*Error); ok {
|
||||
err.Err = errors.Wrap(err.Err, m)
|
||||
e = err
|
||||
|
@ -111,25 +114,12 @@ func Wrapf(status int, e error, format string, args ...interface{}) error {
|
|||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
var opts []Option
|
||||
for i, arg := range args {
|
||||
// Once we find the first Option, assume that all further arguments are Options.
|
||||
if _, ok := arg.(Option); ok {
|
||||
for _, a := range args[i:] {
|
||||
// Ignore any arguments after the first Option that are not Options.
|
||||
if opt, ok := a.(Option); ok {
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
args = args[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
as, opts := splitOptionArgs(args)
|
||||
if err, ok := e.(*Error); ok {
|
||||
err.Err = errors.Wrapf(err.Err, format, args...)
|
||||
e = err
|
||||
} else {
|
||||
e = errors.Wrapf(e, format, args...)
|
||||
e = errors.Wrapf(e, format, as...)
|
||||
}
|
||||
return StatusCodeError(status, e, opts...)
|
||||
}
|
||||
|
@ -174,77 +164,172 @@ type Messenger interface {
|
|||
func StatusCodeError(code int, e error, opts ...Option) error {
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
return BadRequest(e, opts...)
|
||||
return BadRequestErr(e, opts...)
|
||||
case http.StatusUnauthorized:
|
||||
return Unauthorized(e, opts...)
|
||||
return UnauthorizedErr(e, opts...)
|
||||
case http.StatusForbidden:
|
||||
return Forbidden(e, opts...)
|
||||
return ForbiddenErr(e, opts...)
|
||||
case http.StatusInternalServerError:
|
||||
return InternalServerError(e, opts...)
|
||||
return InternalServerErr(e, opts...)
|
||||
case http.StatusNotImplemented:
|
||||
return NotImplemented(e, opts...)
|
||||
return NotImplementedErr(e, opts...)
|
||||
default:
|
||||
return UnexpectedError(code, e, opts...)
|
||||
return UnexpectedErr(code, e, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
var seeLogs = "Please see the certificate authority logs for more info."
|
||||
var (
|
||||
seeLogs = "Please see the certificate authority logs for more info."
|
||||
// BadRequestDefaultMsg 400 default msg
|
||||
BadRequestDefaultMsg = "The request could not be completed; malformed or missing data" + seeLogs
|
||||
// UnauthorizedDefaultMsg 401 default msg
|
||||
UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs
|
||||
// ForbiddenDefaultMsg 403 default msg
|
||||
ForbiddenDefaultMsg = "The request was forbidden by the certificate authority. " + seeLogs
|
||||
// NotFoundDefaultMsg 404 default msg
|
||||
NotFoundDefaultMsg = "The requested resource could not be found. " + seeLogs
|
||||
// InternalServerErrorDefaultMsg 500 default msg
|
||||
InternalServerErrorDefaultMsg = "The certificate authority encountered an Internal Server Error. " + seeLogs
|
||||
// NotImplementedDefaultMsg 501 default msg
|
||||
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
|
||||
)
|
||||
|
||||
// InternalServerError returns a 500 error with the given error.
|
||||
func InternalServerError(err error, opts ...Option) error {
|
||||
if len(opts) == 0 {
|
||||
opts = append(opts, WithMessage("The certificate authority encountered an Internal Server Error. "+seeLogs))
|
||||
// splitOptionArgs splits the variadic length args into string formatting args
|
||||
// and Option(s) to apply to an Error.
|
||||
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
|
||||
indexOptionStart := -1
|
||||
for i, a := range args {
|
||||
if _, ok := a.(Option); ok {
|
||||
indexOptionStart = i
|
||||
break
|
||||
}
|
||||
return New(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
|
||||
if indexOptionStart < 0 {
|
||||
return args, []Option{}
|
||||
}
|
||||
opts := []Option{}
|
||||
// Ignore any non-Option args that come after the first Option.
|
||||
for _, o := range args[indexOptionStart:] {
|
||||
if opt, ok := o.(Option); ok {
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
return args[:indexOptionStart], opts
|
||||
}
|
||||
|
||||
// NotImplemented returns a 501 error with the given error.
|
||||
func NotImplemented(err error, opts ...Option) error {
|
||||
if len(opts) == 0 {
|
||||
opts = append(opts, WithMessage("The requested method is not implemented by the certificate authority. "+seeLogs))
|
||||
// NewErr returns a new Error. If the given error implements the StatusCoder
|
||||
// interface we will ignore the given status.
|
||||
func NewErr(status int, err error, opts ...Option) error {
|
||||
var (
|
||||
e *Error
|
||||
ok bool
|
||||
)
|
||||
if e, ok = err.(*Error); !ok {
|
||||
if sc, ok := err.(StatusCoder); ok {
|
||||
e = &Error{Status: sc.StatusCode(), Err: err}
|
||||
} else {
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := cause.(StatusCoder); ok {
|
||||
e = &Error{Status: sc.StatusCode(), Err: err}
|
||||
} else {
|
||||
e = &Error{Status: status, Err: err}
|
||||
}
|
||||
return New(http.StatusNotImplemented, err, opts...)
|
||||
}
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// BadRequest returns an 400 error with the given error.
|
||||
func BadRequest(err error, opts ...Option) error {
|
||||
if len(opts) == 0 {
|
||||
opts = append(opts, WithMessage("The request could not be completed due to being poorly formatted or "+
|
||||
"missing critical data. "+seeLogs))
|
||||
// Errorf creates a new error using the given format and status code.
|
||||
func Errorf(code int, format string, args ...interface{}) error {
|
||||
as, opts := splitOptionArgs(args)
|
||||
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
|
||||
e := &Error{Status: code, Err: fmt.Errorf(format, as...)}
|
||||
for _, o := range opts {
|
||||
o(e)
|
||||
}
|
||||
return New(http.StatusBadRequest, err, opts...)
|
||||
return e
|
||||
}
|
||||
|
||||
// Unauthorized returns an 401 error with the given error.
|
||||
func Unauthorized(err error, opts ...Option) error {
|
||||
if len(opts) == 0 {
|
||||
opts = append(opts, WithMessage("The request lacked necessary authorization to be completed. "+seeLogs))
|
||||
}
|
||||
return New(http.StatusUnauthorized, err, opts...)
|
||||
// InternalServer creates a 500 error with the given format and arguments.
|
||||
func InternalServer(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
|
||||
return Errorf(http.StatusInternalServerError, format, args...)
|
||||
}
|
||||
|
||||
// Forbidden returns an 403 error with the given error.
|
||||
func Forbidden(err error, opts ...Option) error {
|
||||
if len(opts) == 0 {
|
||||
opts = append(opts, WithMessage("The request was Forbidden by the certificate authority. "+seeLogs))
|
||||
}
|
||||
return New(http.StatusForbidden, err, opts...)
|
||||
// InternalServerErr returns a 500 error with the given error.
|
||||
func InternalServerErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg))
|
||||
return NewErr(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
|
||||
// NotFound returns an 404 error with the given error.
|
||||
func NotFound(err error, opts ...Option) error {
|
||||
if len(opts) == 0 {
|
||||
opts = append(opts, WithMessage("The requested resource could not be found. "+seeLogs))
|
||||
}
|
||||
return New(http.StatusNotFound, err, opts...)
|
||||
// NotImplemented creates a 501 error with the given format and arguments.
|
||||
func NotImplemented(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(NotImplementedDefaultMsg))
|
||||
return Errorf(http.StatusNotImplemented, format, args...)
|
||||
}
|
||||
|
||||
// UnexpectedError will be used when the certificate authority makes an outgoing
|
||||
// NotImplementedErr returns a 501 error with the given error.
|
||||
func NotImplementedErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
|
||||
return NewErr(http.StatusNotImplemented, err, opts...)
|
||||
}
|
||||
|
||||
// BadRequest creates a 400 error with the given format and arguments.
|
||||
func BadRequest(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(BadRequestDefaultMsg))
|
||||
return Errorf(http.StatusBadRequest, format, args...)
|
||||
}
|
||||
|
||||
// BadRequestErr returns an 400 error with the given error.
|
||||
func BadRequestErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
|
||||
return NewErr(http.StatusBadRequest, err, opts...)
|
||||
}
|
||||
|
||||
// Unauthorized creates a 401 error with the given format and arguments.
|
||||
func Unauthorized(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(UnauthorizedDefaultMsg))
|
||||
return Errorf(http.StatusUnauthorized, format, args...)
|
||||
}
|
||||
|
||||
// UnauthorizedErr returns an 401 error with the given error.
|
||||
func UnauthorizedErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg))
|
||||
return NewErr(http.StatusUnauthorized, err, opts...)
|
||||
}
|
||||
|
||||
// Forbidden creates a 403 error with the given format and arguments.
|
||||
func Forbidden(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(ForbiddenDefaultMsg))
|
||||
return Errorf(http.StatusForbidden, format, args...)
|
||||
}
|
||||
|
||||
// ForbiddenErr returns an 403 error with the given error.
|
||||
func ForbiddenErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
|
||||
return NewErr(http.StatusForbidden, err, opts...)
|
||||
}
|
||||
|
||||
// NotFound creates a 404 error with the given format and arguments.
|
||||
func NotFound(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(NotFoundDefaultMsg))
|
||||
return Errorf(http.StatusNotFound, format, args...)
|
||||
}
|
||||
|
||||
// NotFoundErr returns an 404 error with the given error.
|
||||
func NotFoundErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(NotFoundDefaultMsg))
|
||||
return NewErr(http.StatusNotFound, err, opts...)
|
||||
}
|
||||
|
||||
// UnexpectedErr will be used when the certificate authority makes an outgoing
|
||||
// request and receives an unhandled status code.
|
||||
func UnexpectedError(code int, err error, opts ...Option) error {
|
||||
if len(opts) == 0 {
|
||||
opts = append(opts, WithMessage("The certificate authority received an "+
|
||||
func UnexpectedErr(code int, err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage("The certificate authority received an "+
|
||||
"unexpected HTTP status code - '%d'. "+seeLogs, code))
|
||||
}
|
||||
return New(code, err, opts...)
|
||||
return NewErr(code, err, opts...)
|
||||
}
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
package api
|
||||
package errs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
func TestError_MarshalJSON(t *testing.T) {
|
||||
|
@ -24,7 +22,7 @@ func TestError_MarshalJSON(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &errs.Error{
|
||||
e := &Error{
|
||||
Status: tt.fields.Status,
|
||||
Err: tt.fields.Err,
|
||||
}
|
||||
|
@ -47,15 +45,15 @@ func TestError_UnmarshalJSON(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
expected *errs.Error
|
||||
expected *Error
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &errs.Error{Status: 400, Err: fmt.Errorf("bad request")}, false},
|
||||
{"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &errs.Error{}, true},
|
||||
{"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &Error{Status: 400, Err: fmt.Errorf("bad request")}, false},
|
||||
{"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &Error{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := new(errs.Error)
|
||||
e := new(Error)
|
||||
if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
Loading…
Reference in a new issue