diff --git a/api/api.go b/api/api.go index 33aa0f44..c4b307b3 100644 --- a/api/api.go +++ b/api/api.go @@ -5,7 +5,6 @@ import ( "crypto/dsa" "crypto/ecdsa" "crypto/rsa" - "crypto/tls" "crypto/x509" "encoding/asn1" "encoding/base64" @@ -209,14 +208,6 @@ type RootResponse struct { RootPEM Certificate `json:"ca"` } -// SignRequest is the request body for a certificate signature request. -type SignRequest struct { - CsrPEM CertificateRequest `json:"csr"` - OTT string `json:"ott"` - NotAfter TimeDuration `json:"notAfter"` - NotBefore TimeDuration `json:"notBefore"` -} - // ProvisionersResponse is the response object that returns the list of // provisioners. type ProvisionersResponse struct { @@ -230,31 +221,6 @@ type ProvisionerKeyResponse struct { Key string `json:"key"` } -// Validate checks the fields of the SignRequest and returns nil if they are ok -// or an error if something is wrong. -func (s *SignRequest) Validate() error { - if s.CsrPEM.CertificateRequest == nil { - return errs.BadRequest(errors.New("missing csr")) - } - if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { - return errs.BadRequest(errors.Wrap(err, "invalid csr")) - } - if s.OTT == "" { - return errs.BadRequest(errors.New("missing ott")) - } - - return nil -} - -// SignResponse is the response object of the certificate signature request. -type SignResponse struct { - ServerPEM Certificate `json:"crt"` - CaPEM Certificate `json:"ca"` - CertChainPEM []Certificate `json:"certChain"` - TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"` - TLS *tls.ConnectionState `json:"-"` -} - // RootsResponse is the response object of the roots request. type RootsResponse struct { Certificates []Certificate `json:"crts"` @@ -344,80 +310,6 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { return certChainPEM } -// Sign is an HTTP handler that reads a certificate request and an -// one-time-token (ott) from the body and creates a new certificate with the -// information in the certificate request. -func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { - var body SignRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) - return - } - - logOtt(w, body.OTT) - if err := body.Validate(); err != nil { - WriteError(w, err) - return - } - - opts := provisioner.Options{ - NotBefore: body.NotBefore, - NotAfter: body.NotAfter, - } - - signOpts, err := h.Authority.AuthorizeSign(body.OTT) - if err != nil { - WriteError(w, errs.Unauthorized(err)) - return - } - - certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) - if err != nil { - WriteError(w, errs.Forbidden(err)) - return - } - certChainPEM := certChainToPEM(certChain) - var caPEM Certificate - if len(certChainPEM) > 0 { - caPEM = certChainPEM[1] - } - logCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ - ServerPEM: certChainPEM[0], - CaPEM: caPEM, - CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), - }, http.StatusCreated) -} - -// Renew uses the information of certificate in the TLS connection to create a -// new one. -func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest(errors.New("missing peer certificate"))) - return - } - - certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) - if err != nil { - WriteError(w, errs.Forbidden(err)) - return - } - certChainPEM := certChainToPEM(certChain) - var caPEM Certificate - if len(certChainPEM) > 0 { - caPEM = certChainPEM[1] - } - - logCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ - ServerPEM: certChainPEM[0], - CaPEM: caPEM, - CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), - }, http.StatusCreated) -} - // Provisioners returns the list of provisioners configured in the authority. func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := parseCursor(r) diff --git a/api/api_test.go b/api/api_test.go index 70ba6a89..9f40a8e0 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -28,6 +28,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/sshutil" "github.com/smallstep/certificates/templates" @@ -914,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) { {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + {"renew error", cs, nil, nil, errs.Forbidden(fmt.Errorf("an error")), http.StatusForbidden}, } expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) @@ -934,13 +935,13 @@ func Test_caHandler_Renew(t *testing.T) { res := w.Result() if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { - t.Errorf("caHandler.Root unexpected error = %v", err) + t.Errorf("caHandler.Renew unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), expected) { @@ -1009,8 +1010,12 @@ func Test_caHandler_Provisioners(t *testing.T) { t.Fatal(err) } - expectedError400 := []byte(`{"status":400,"message":"Bad Request"}`) - expectedError500 := []byte(`{"status":500,"message":"Internal Server Error"}`) + expectedError400 := errs.BadRequest(errors.New("force")) + expectedError400Bytes, err := json.Marshal(expectedError400) + assert.FatalError(t, err) + expectedError500 := errs.InternalServerError(errors.New("force")) + expectedError500Bytes, err := json.Marshal(expectedError500) + assert.FatalError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := &caHandler{ @@ -1035,12 +1040,12 @@ func Test_caHandler_Provisioners(t *testing.T) { } else { switch tt.statusCode { case 400: - if !bytes.Equal(bytes.TrimSpace(body), expectedError400) { - t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400) + if !bytes.Equal(bytes.TrimSpace(body), expectedError400Bytes) { + t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400Bytes) } case 500: - if !bytes.Equal(bytes.TrimSpace(body), expectedError500) { - t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500) + if !bytes.Equal(bytes.TrimSpace(body), expectedError500Bytes) { + t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500Bytes) } default: t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode) @@ -1077,7 +1082,9 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { } expected := []byte(`{"key":"` + privKey + `"}`) - expectedError := []byte(`{"status":404,"message":"Not Found"}`) + expectedError404 := errs.NotFound(errors.New("force")) + expectedError404Bytes, err := json.Marshal(expectedError404) + assert.FatalError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1101,8 +1108,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected) } } else { - if !bytes.Equal(bytes.TrimSpace(body), expectedError) { - t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError) + if !bytes.Equal(bytes.TrimSpace(body), expectedError404Bytes) { + t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError404Bytes) } } }) diff --git a/api/renew.go b/api/renew.go new file mode 100644 index 00000000..bc42ec24 --- /dev/null +++ b/api/renew.go @@ -0,0 +1,36 @@ +package api + +import ( + "net/http" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" +) + +// Renew uses the information of certificate in the TLS connection to create a +// new one. +func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + WriteError(w, errs.BadRequest(errors.New("missing peer certificate"))) + return + } + + certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) + if err != nil { + WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) + return + } + certChainPEM := certChainToPEM(certChain) + var caPEM Certificate + if len(certChainPEM) > 0 { + caPEM = certChainPEM[1] + } + + logCertificate(w, certChain[0]) + JSONStatus(w, &SignResponse{ + ServerPEM: certChainPEM[0], + CaPEM: caPEM, + CertChainPEM: certChainPEM, + TLSOptions: h.Authority.GetTLSOptions(), + }, http.StatusCreated) +} diff --git a/api/sign.go b/api/sign.go new file mode 100644 index 00000000..e76f6256 --- /dev/null +++ b/api/sign.go @@ -0,0 +1,90 @@ +package api + +import ( + "crypto/tls" + "net/http" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/cli/crypto/tlsutil" +) + +// SignRequest is the request body for a certificate signature request. +type SignRequest struct { + CsrPEM CertificateRequest `json:"csr"` + OTT string `json:"ott"` + NotAfter TimeDuration `json:"notAfter"` + NotBefore TimeDuration `json:"notBefore"` +} + +// Validate checks the fields of the SignRequest and returns nil if they are ok +// or an error if something is wrong. +func (s *SignRequest) Validate() error { + if s.CsrPEM.CertificateRequest == nil { + return errs.BadRequest(errors.New("missing csr")) + } + if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { + return errs.BadRequest(errors.Wrap(err, "invalid csr")) + } + if s.OTT == "" { + return errs.BadRequest(errors.New("missing ott")) + } + + return nil +} + +// SignResponse is the response object of the certificate signature request. +type SignResponse struct { + ServerPEM Certificate `json:"crt"` + CaPEM Certificate `json:"ca"` + CertChainPEM []Certificate `json:"certChain"` + TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"` + TLS *tls.ConnectionState `json:"-"` +} + +// Sign is an HTTP handler that reads a certificate request and an +// one-time-token (ott) from the body and creates a new certificate with the +// information in the certificate request. +func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { + var body SignRequest + if err := ReadJSON(r.Body, &body); err != nil { + WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + return + } + + logOtt(w, body.OTT) + if err := body.Validate(); err != nil { + WriteError(w, err) + return + } + + opts := provisioner.Options{ + NotBefore: body.NotBefore, + NotAfter: body.NotAfter, + } + + signOpts, err := h.Authority.AuthorizeSign(body.OTT) + if err != nil { + WriteError(w, errs.Unauthorized(err)) + return + } + + certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) + if err != nil { + WriteError(w, errs.Forbidden(err)) + return + } + certChainPEM := certChainToPEM(certChain) + var caPEM Certificate + if len(certChainPEM) > 0 { + caPEM = certChainPEM[1] + } + logCertificate(w, certChain[0]) + JSONStatus(w, &SignResponse{ + ServerPEM: certChainPEM[0], + CaPEM: caPEM, + CertChainPEM: certChainPEM, + TLSOptions: h.Authority.GetTLSOptions(), + }, http.StatusCreated) +} diff --git a/api/ssh.go b/api/ssh.go index f125a95a..2206973b 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -282,7 +282,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ValidAfter: body.ValidAfter, } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { WriteError(w, errs.Unauthorized(err)) diff --git a/api/sshRekey.go b/api/sshRekey.go index aa70cf4f..efeee141 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -56,13 +56,13 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { return } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RekeySSHMethod) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { WriteError(w, errs.Unauthorized(err)) return } - oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT) + oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { WriteError(w, errs.InternalServerError(err)) } diff --git a/api/sshRenew.go b/api/sshRenew.go index 5165bf33..fd4ff1ee 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -46,13 +46,13 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { return } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RenewSSHMethod) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod) _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { WriteError(w, errs.Unauthorized(err)) return } - oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT) + oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { WriteError(w, errs.InternalServerError(err)) } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 93e0e450..cd4a3a3e 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -66,7 +66,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { PassiveOnly: body.Passive, } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeSSHMethod) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod) // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) diff --git a/authority/authority_test.go b/authority/authority_test.go index ee517517..e6a65453 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -13,12 +13,13 @@ import ( stepJOSE "github.com/smallstep/cli/jose" ) -func testAuthority(t *testing.T) *Authority { +func testAuthority(t *testing.T, opts ...Option) *Authority { maxjwk, err := stepJOSE.ParseKey("testdata/secrets/max_pub.jwk") assert.FatalError(t, err) clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) disableRenewal := true + enableSSHCA := true p := provisioner.List{ &provisioner.JWK{ Name: "Max", @@ -29,6 +30,9 @@ func testAuthority(t *testing.T) *Authority { Name: "step-cli", Type: "JWK", Key: clijwk, + Claims: &provisioner.Claims{ + EnableSSHCA: &enableSSHCA, + }, }, &provisioner.JWK{ Name: "dev", @@ -46,19 +50,30 @@ func testAuthority(t *testing.T) *Authority { DisableRenewal: &disableRenewal, }, }, + &provisioner.SSHPOP{ + Name: "sshpop", + Type: "SSHPOP", + Claims: &provisioner.Claims{ + EnableSSHCA: &enableSSHCA, + }, + }, } c := &Config{ Address: "127.0.0.1:443", Root: []string{"testdata/certs/root_ca.crt"}, IntermediateCert: "testdata/certs/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", - DNSNames: []string{"test.ca.smallstep.com"}, - Password: "pass", + SSH: &SSHConfig{ + HostKey: "testdata/secrets/ssh_host_ca_key", + UserKey: "testdata/secrets/ssh_user_ca_key", + }, + DNSNames: []string{"example.com"}, + Password: "pass", AuthorityConfig: &AuthConfig{ Provisioners: p, }, } - a, err := New(c) + a, err := New(c, opts...) assert.FatalError(t, err) return a } diff --git a/authority/authorize.go b/authority/authorize.go index 3353c6b1..cdca026d 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -8,7 +8,9 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" + "golang.org/x/crypto/ssh" ) // Claims extends jose.Claims with step attributes. @@ -36,22 +38,19 @@ func SkipTokenReuseFromContext(ctx context.Context) bool { // authorizeToken parses the token and returns the provisioner used to generate // the token. This method enforces the One-Time use policy (tokens can only be // used once). -func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner.Interface, error) { - var errContext = map[string]interface{}{"ott": ott} - +func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) { // Validate payload - token, err := jose.ParseSigned(ott) + tok, err := jose.ParseSigned(token) if err != nil { - return nil, &apiError{errors.Wrapf(err, "authorizeToken: error parsing token"), - http.StatusUnauthorized, errContext} + return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token") } // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. var claims Claims - if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, &apiError{errors.Wrap(err, "authorizeToken"), http.StatusUnauthorized, errContext} + if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") } // TODO: use new persistence layer abstraction. @@ -59,29 +58,27 @@ func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner // This check is meant as a stopgap solution to the current lack of a persistence layer. if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { - return nil, &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"), - http.StatusUnauthorized, errContext} + return nil, errs.Unauthorized(errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority")) } } // This method will also validate the audiences for JWK provisioners. - p, ok := a.provisioners.LoadByToken(token, &claims.Claims) + p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) if !ok { - return nil, &apiError{ - errors.Errorf("authorizeToken: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")), - http.StatusUnauthorized, errContext} + return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: provisioner "+ + "not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))) } // Store the token to protect against reuse unless it's skipped. if !SkipTokenReuseFromContext(ctx) { - if reuseKey, err := p.GetTokenID(ott); err == nil { - ok, err := a.db.UseToken(reuseKey, ott) + if reuseKey, err := p.GetTokenID(token); err == nil { + ok, err := a.db.UseToken(reuseKey, token) if err != nil { - return nil, &apiError{errors.Wrap(err, "authorizeToken: failed when attempting to store token"), - http.StatusInternalServerError, errContext} + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.authorizeToken: failed when attempting to store token") } if !ok { - return nil, &apiError{errors.Errorf("authorizeToken: token already used"), http.StatusUnauthorized, errContext} + return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: token already used")) } } } @@ -89,125 +86,158 @@ func (a *Authority) authorizeToken(ctx context.Context, ott string) (provisioner return p, nil } -// Authorize grabs the method from the context and authorizes a signature -// request by validating the one-time-token. -func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - var errContext = apiCtx{"ott": ott} +// Authorize grabs the method from the context and authorizes the request by +// validating the one-time-token. +func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) { + var opts = []errs.Option{errs.WithKeyVal("token", token)} + switch m := provisioner.MethodFromContext(ctx); m { case provisioner.SignMethod: - return a.authorizeSign(ctx, ott) + signOpts, err := a.authorizeSign(ctx, token) + return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) case provisioner.RevokeMethod: - return nil, a.authorizeRevoke(ctx, ott) - case provisioner.SignSSHMethod: + return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...) + case provisioner.SSHSignMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { - return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext} + return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) } - return a.authorizeSSHSign(ctx, ott) - case provisioner.RenewSSHMethod: + _, err := a.authorizeSSHSign(ctx, token) + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) + case provisioner.SSHRenewMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { - return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext} + return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) } - if _, err := a.authorizeSSHRenew(ctx, ott); err != nil { - return nil, err - } - return nil, nil - case provisioner.RevokeSSHMethod: - return nil, a.authorizeSSHRevoke(ctx, ott) - case provisioner.RekeySSHMethod: + _, err := a.authorizeSSHRenew(ctx, token) + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) + case provisioner.SSHRevokeMethod: + return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...) + case provisioner.SSHRekeyMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { - return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext} + return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) } - _, opts, err := a.authorizeSSHRekey(ctx, ott) - if err != nil { - return nil, err - } - return opts, nil + _, signOpts, err := a.authorizeSSHRekey(ctx, token) + return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) default: - return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext} + return nil, errs.InternalServerError(errors.Errorf("authority.Authorize; method %d is not supported", m), opts...) } } -// authorizeSign loads the provisioner from the token, checks that it has not -// been used again and calls the provisioner AuthorizeSign method. Returns a -// list of methods to apply to the signing flow. -func (a *Authority) authorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - var errContext = apiCtx{"ott": ott} - p, err := a.authorizeToken(ctx, ott) +// authorizeSign loads the provisioner from the token and calls the provisioner +// AuthorizeSign method. Returns a list of methods to apply to the signing flow. +func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) { + p, err := a.authorizeToken(ctx, token) if err != nil { - return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext} + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign") } - opts, err := p.AuthorizeSign(ctx, ott) + signOpts, err := p.AuthorizeSign(ctx, token) if err != nil { - return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext} + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign") } - return opts, nil + return signOpts, nil } // AuthorizeSign authorizes a signature request by validating and authenticating -// a OTT that must be sent w/ the request. +// a token that must be sent w/ the request. // // NOTE: This method is deprecated and should not be used. We make it available // in the short term os as not to break existing clients. -func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { +func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) - return a.Authorize(ctx, ott) + return a.Authorize(ctx, token) } -// authorizeRevoke authorizes a revocation request by validating and authenticating -// the RevokeOptions POSTed with the request. -// Returns a tuple of the provisioner ID and error, if one occurred. +// authorizeRevoke locates the provisioner used to generate the authenticating +// token and then performs the token validation flow. func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { - errContext := map[string]interface{}{"ott": token} - p, err := a.authorizeToken(ctx, token) if err != nil { - return &apiError{errors.Wrap(err, "authorizeRevoke"), http.StatusUnauthorized, errContext} + return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") } if err = p.AuthorizeRevoke(ctx, token); err != nil { - return &apiError{errors.Wrap(err, "authorizeRevoke"), http.StatusUnauthorized, errContext} + return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") } return nil } -// authorizeRenewl tries to locate the step provisioner extension, and checks +// authorizeRenew locates the provisioner (using the provisioner extension in the cert), and checks // if for the configured provisioner, the renewal is enabled or not. If the // extra extension cannot be found, authorize the renewal by default. // // TODO(mariano): should we authorize by default? -func (a *Authority) authorizeRenew(crt *x509.Certificate) error { - errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()} +func (a *Authority) authorizeRenew(cert *x509.Certificate) error { + var opts = []errs.Option{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())} // Check the passive revocation table. - isRevoked, err := a.db.IsRevoked(crt.SerialNumber.String()) + isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String()) if err != nil { - return &apiError{ - err: errors.Wrap(err, "renew"), - code: http.StatusInternalServerError, - context: errContext, - } + return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } if isRevoked { - return &apiError{ - err: errors.New("renew: certificate has been revoked"), - code: http.StatusUnauthorized, - context: errContext, - } + return errs.Unauthorized(errors.New("authority.authorizeRenew: certificate has been revoked"), opts...) } - p, ok := a.provisioners.LoadByCertificate(crt) + p, ok := a.provisioners.LoadByCertificate(cert) if !ok { - return &apiError{ - err: errors.New("renew: provisioner not found"), - code: http.StatusUnauthorized, - context: errContext, - } + return errs.Unauthorized(errors.New("authority.authorizeRenew: provisioner not found"), opts...) } - if err := p.AuthorizeRenew(context.Background(), crt); err != nil { - return &apiError{ - err: errors.Wrap(err, "renew"), - code: http.StatusUnauthorized, - context: errContext, - } + if err := p.AuthorizeRenew(context.Background(), cert); err != nil { + return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) + } + return nil +} + +// authorizeSSHSign loads the provisioner from the token, checks that it has not +// been used again and calls the provisioner AuthorizeSSHSign method. Returns a +// list of methods to apply to the signing flow. +func (a *Authority) authorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) { + p, err := a.authorizeToken(ctx, token) + if err != nil { + return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign") + } + signOpts, err := p.AuthorizeSSHSign(ctx, token) + if err != nil { + return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign") + } + return signOpts, nil +} + +// authorizeSSHRenew authorizes an SSH certificate renewal request, by +// validating the contents of an SSHPOP token. +func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { + p, err := a.authorizeToken(ctx, token) + if err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew") + } + cert, err := p.AuthorizeSSHRenew(ctx, token) + if err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew") + } + return cert, nil +} + +// authorizeSSHRekey authorizes an SSH certificate rekey request, by +// validating the contents of an SSHPOP token. +func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) { + p, err := a.authorizeToken(ctx, token) + if err != nil { + return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey") + } + cert, signOpts, err := p.AuthorizeSSHRekey(ctx, token) + if err != nil { + return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey") + } + return cert, signOpts, nil +} + +// authorizeSSHRevoke authorizes an SSH certificate revoke request, by +// validating the contents of an SSHPOP token. +func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error { + p, err := a.authorizeToken(ctx, token) + if err != nil { + return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke") + } + if err = p.AuthorizeSSHRevoke(ctx, token); err != nil { + return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke") } return nil } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 5e112e95..6f7bf940 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -2,25 +2,58 @@ package authority import ( "context" + "crypto" + "crypto/rand" "crypto/x509" + "encoding/base64" + "fmt" "net/http" + "strconv" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/jose" + "golang.org/x/crypto/ssh" "gopkg.in/square/go-jose.v2/jwt" ) -func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { - sig, err := jose.NewSigner( - jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, - new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), - ) +var testAudiences = provisioner.Audiences{ + Sign: []string{"https://example.com/1.0/sign", "https://example.com/sign"}, + Revoke: []string{"https://example.com/1.0/revoke", "https://example.com/revoke"}, + SSHSign: []string{"https://example.com/1.0/ssh/sign"}, + SSHRevoke: []string{"https://example.com/1.0/ssh/revoke"}, + SSHRenew: []string{"https://example.com/1.0/ssh/renew"}, + SSHRekey: []string{"https://example.com/1.0/ssh/rekey"}, +} + +type tokOption func(*jose.SignerOptions) error + +func withSSHPOPFile(cert *ssh.Certificate) tokOption { + return func(so *jose.SignerOptions) error { + so.WithHeader("sshpop", base64.StdEncoding.EncodeToString(cert.Marshal())) + return nil + } +} + +func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("kid", jwk.KeyID) + + for _, o := range tokOpts { + if err := o(so); err != nil { + return "", err + } + } + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) if err != nil { return "", err } @@ -61,20 +94,21 @@ func TestAuthority_authorizeToken(t *testing.T) { now := time.Now().UTC() validIssuer := "step-cli" - validAudience := []string{"https://test.ca.smallstep.com/revoke"} + validAudience := []string{"https://example.com/revoke"} type authorizeTest struct { - auth *Authority - ott string - err *apiError + auth *Authority + token string + err error + code int } tests := map[string]func(t *testing.T) *authorizeTest{ - "fail/invalid-ott": func(t *testing.T) *authorizeTest { + "fail/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ - auth: a, - ott: "foo", - err: &apiError{errors.New("authorizeToken: error parsing token"), - http.StatusUnauthorized, apiCtx{"ott": "foo"}}, + auth: a, + token: "foo", + err: errors.New("authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, } }, "fail/prehistoric-token": func(t *testing.T) *authorizeTest { @@ -90,10 +124,10 @@ func TestAuthority_authorizeToken(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: a, - ott: raw, - err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"), - http.StatusUnauthorized, apiCtx{"ott": raw}}, + auth: a, + token: raw, + err: errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"), + code: http.StatusUnauthorized, } }, "fail/provisioner-not-found": func(t *testing.T) *authorizeTest { @@ -112,10 +146,10 @@ func TestAuthority_authorizeToken(t *testing.T) { raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: a, - ott: raw, - err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"), - http.StatusUnauthorized, apiCtx{"ott": raw}}, + auth: a, + token: raw, + err: errors.New("authority.authorizeToken: provisioner not found or invalid audience (https://example.com/revoke)"), + code: http.StatusUnauthorized, } }, "ok/simpledb": func(t *testing.T) *authorizeTest { @@ -130,8 +164,8 @@ func TestAuthority_authorizeToken(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: a, - ott: raw, + auth: a, + token: raw, } }, "fail/simpledb/token-already-used": func(t *testing.T) *authorizeTest { @@ -149,16 +183,16 @@ func TestAuthority_authorizeToken(t *testing.T) { _, err = _a.authorizeToken(context.TODO(), raw) assert.FatalError(t, err) return &authorizeTest{ - auth: _a, - ott: raw, - err: &apiError{errors.New("authorizeToken: token already used"), - http.StatusUnauthorized, apiCtx{"ott": raw}}, + auth: _a, + token: raw, + err: errors.New("authority.authorizeToken: token already used"), + code: http.StatusUnauthorized, } }, "ok/mockNoSQLDB": func(t *testing.T) *authorizeTest { _a := testAuthority(t) - _a.db = &MockAuthDB{ - useToken: func(id, tok string) (bool, error) { + _a.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { return true, nil }, } @@ -174,14 +208,14 @@ func TestAuthority_authorizeToken(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: _a, - ott: raw, + auth: _a, + token: raw, } }, "fail/mockNoSQLDB/error": func(t *testing.T) *authorizeTest { _a := testAuthority(t) - _a.db = &MockAuthDB{ - useToken: func(id, tok string) (bool, error) { + _a.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { return false, errors.New("force") }, } @@ -197,16 +231,16 @@ func TestAuthority_authorizeToken(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: _a, - ott: raw, - err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"), - http.StatusInternalServerError, apiCtx{"ott": raw}}, + auth: _a, + token: raw, + err: errors.New("authority.authorizeToken: failed when attempting to store token: force"), + code: http.StatusInternalServerError, } }, "fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest { _a := testAuthority(t) - _a.db = &MockAuthDB{ - useToken: func(id, tok string) (bool, error) { + _a.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { return false, nil }, } @@ -222,10 +256,10 @@ func TestAuthority_authorizeToken(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: _a, - ott: raw, - err: &apiError{errors.New("authorizeToken: token already used"), - http.StatusUnauthorized, apiCtx{"ott": raw}}, + auth: _a, + token: raw, + err: errors.New("authority.authorizeToken: token already used"), + code: http.StatusUnauthorized, } }, } @@ -234,17 +268,13 @@ func TestAuthority_authorizeToken(t *testing.T) { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - p, err := tc.auth.authorizeToken(context.TODO(), tc.ott) + p, err := tc.auth.authorizeToken(context.TODO(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { @@ -268,20 +298,21 @@ func TestAuthority_authorizeRevoke(t *testing.T) { now := time.Now().UTC() validIssuer := "step-cli" - validAudience := []string{"https://test.ca.smallstep.com/revoke"} + validAudience := []string{"https://example.com/revoke"} type authorizeTest struct { auth *Authority token string - opts *RevokeOptions err error + code int } tests := map[string]func(t *testing.T) *authorizeTest{ - "fail/token/invalid-ott": func(t *testing.T) *authorizeTest { + "fail/token/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ - auth: a, - opts: &RevokeOptions{OTT: "foo"}, - err: errors.New("authorizeRevoke: authorizeToken: error parsing token"), + auth: a, + token: "foo", + err: errors.New("authority.authorizeRevoke: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, } }, "fail/token/invalid-subject": func(t *testing.T) *authorizeTest { @@ -296,9 +327,10 @@ func TestAuthority_authorizeRevoke(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: a, - opts: &RevokeOptions{OTT: raw}, - err: errors.New("authorizeRevoke: token subject cannot be empty"), + auth: a, + token: raw, + err: errors.New("authority.authorizeRevoke: jwk.AuthorizeRevoke: jwk.authorizeToken; jwk token subject cannot be empty"), + code: http.StatusUnauthorized, } }, "ok/token": func(t *testing.T) *authorizeTest { @@ -313,34 +345,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: a, - opts: &RevokeOptions{OTT: raw}, - } - }, - "fail/mTLS/invalid-serial": func(t *testing.T) *authorizeTest { - crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") - assert.FatalError(t, err) - return &authorizeTest{ - auth: a, - opts: &RevokeOptions{MTLS: true, Crt: crt, Serial: "foo"}, - err: errors.New("authorizeRevoke: serial number in certificate different than body"), - } - }, - "fail/mTLS/load-provisioner": func(t *testing.T) *authorizeTest { - crt, err := pemutil.ReadCertificate("./testdata/certs/provisioner-not-found.crt") - assert.FatalError(t, err) - return &authorizeTest{ - auth: a, - opts: &RevokeOptions{MTLS: true, Crt: crt, Serial: "41633491264736369593451462439668497527"}, - err: errors.New("authorizeRevoke: provisioner not found"), - } - }, - "ok/mTLS": func(t *testing.T) *authorizeTest { - crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") - assert.FatalError(t, err) - return &authorizeTest{ - auth: a, - opts: &RevokeOptions{MTLS: true, Crt: crt, Serial: "102012593071130646873265215610956555026"}, + auth: a, + token: raw, } }, } @@ -351,6 +357,9 @@ func TestAuthority_authorizeRevoke(t *testing.T) { if err := tc.auth.authorizeRevoke(context.TODO(), tc.token); err != nil { if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { @@ -360,7 +369,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) { } } -func TestAuthority_AuthorizeSign(t *testing.T) { +func TestAuthority_authorizeSign(t *testing.T) { a := testAuthority(t) jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) @@ -373,20 +382,21 @@ func TestAuthority_AuthorizeSign(t *testing.T) { now := time.Now().UTC() validIssuer := "step-cli" - validAudience := []string{"https://test.ca.smallstep.com/sign"} + validAudience := []string{"https://example.com/sign"} type authorizeTest struct { - auth *Authority - ott string - err *apiError + auth *Authority + token string + err error + code int } tests := map[string]func(t *testing.T) *authorizeTest{ - "fail/invalid-ott": func(t *testing.T) *authorizeTest { + "fail/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ - auth: a, - ott: "foo", - err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"), - http.StatusUnauthorized, apiCtx{"ott": "foo"}}, + auth: a, + token: "foo", + err: errors.New("authority.authorizeSign: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, } }, "fail/invalid-subject": func(t *testing.T) *authorizeTest { @@ -401,10 +411,10 @@ func TestAuthority_AuthorizeSign(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: a, - ott: raw, - err: &apiError{errors.New("authorizeSign: token subject cannot be empty"), - http.StatusUnauthorized, apiCtx{"ott": raw}}, + auth: a, + token: raw, + err: errors.New("authority.authorizeSign: jwk.AuthorizeSign: jwk.authorizeToken; jwk token subject cannot be empty"), + code: http.StatusUnauthorized, } }, "ok": func(t *testing.T) *authorizeTest { @@ -419,8 +429,8 @@ func TestAuthority_AuthorizeSign(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: a, - ott: raw, + auth: a, + token: raw, } }, } @@ -429,18 +439,13 @@ func TestAuthority_AuthorizeSign(t *testing.T) { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - got, err := tc.auth.AuthorizeSign(tc.ott) + got, err := tc.auth.authorizeSign(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - assert.Nil(t, got) - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { @@ -451,7 +456,6 @@ func TestAuthority_AuthorizeSign(t *testing.T) { } } -// TODO: remove once Authorize deprecated. func TestAuthority_Authorize(t *testing.T) { a := testAuthority(t) @@ -463,22 +467,456 @@ func TestAuthority_Authorize(t *testing.T) { assert.FatalError(t, err) now := time.Now().UTC() - validIssuer := "step-cli" - validAudience := []string{"https://test.ca.smallstep.com/sign"} + + type authorizeTest struct { + auth *Authority + token string + ctx context.Context + err error + code int + } + tests := map[string]func(t *testing.T) *authorizeTest{ + "default-to-signMethod": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + ctx: context.Background(), + err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "fail/sign/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), + err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "ok/sign": func(t *testing.T) *authorizeTest { + cl := jwt.Claims{ + Subject: "test.smallstep.com", + Issuer: validIssuer, + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Minute)), + Audience: testAudiences.Sign, + ID: "1", + } + token, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + assert.FatalError(t, err) + return &authorizeTest{ + auth: a, + token: token, + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), + } + }, + "fail/revoke/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), + err: errors.New("authority.Authorize: authority.authorizeRevoke: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "ok/revoke": func(t *testing.T) *authorizeTest { + cl := jwt.Claims{ + Subject: "test.smallstep.com", + Issuer: validIssuer, + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Minute)), + Audience: testAudiences.Revoke, + ID: "2", + } + token, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + assert.FatalError(t, err) + return &authorizeTest{ + auth: a, + token: token, + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), + } + }, + "fail/sshSign/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), + err: errors.New("authority.Authorize: authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "fail/sshSign/disabled": func(t *testing.T) *authorizeTest { + _a := testAuthority(t) + _a.sshCAHostCertSignKey = nil + _a.sshCAUserCertSignKey = nil + return &authorizeTest{ + auth: _a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), + err: errors.New("authority.Authorize; ssh certificate flows are not enabled"), + code: http.StatusNotImplemented, + } + }, + "ok/sshSign": func(t *testing.T) *authorizeTest { + raw, err := generateSimpleSSHUserToken(validIssuer, testAudiences.SSHSign[0], jwk) + assert.FatalError(t, err) + return &authorizeTest{ + auth: a, + token: raw, + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), + } + }, + "fail/sshRenew/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), + err: errors.New("authority.Authorize: authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "fail/sshRenew/disabled": func(t *testing.T) *authorizeTest { + _a := testAuthority(t) + _a.sshCAHostCertSignKey = nil + _a.sshCAUserCertSignKey = nil + return &authorizeTest{ + auth: _a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), + err: errors.New("authority.Authorize; ssh certificate flows are not enabled"), + code: http.StatusNotImplemented, + } + }, + "ok/sshRenew": func(t *testing.T) *authorizeTest { + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + + cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + + tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", + []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return &authorizeTest{ + auth: a, + token: tok, + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), + } + }, + "fail/sshRevoke/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), + err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "ok/sshRevoke": func(t *testing.T) *authorizeTest { + cl := jwt.Claims{ + Subject: "test.smallstep.com", + Issuer: validIssuer, + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Minute)), + Audience: testAudiences.SSHRevoke, + ID: "3", + } + token, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + assert.FatalError(t, err) + return &authorizeTest{ + auth: a, + token: token, + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), + } + }, + "fail/sshRekey/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), + err: errors.New("authority.Authorize: authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "fail/sshRekey/disabled": func(t *testing.T) *authorizeTest { + _a := testAuthority(t) + _a.sshCAHostCertSignKey = nil + _a.sshCAUserCertSignKey = nil + return &authorizeTest{ + auth: _a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), + err: errors.New("authority.Authorize; ssh certificate flows are not enabled"), + code: http.StatusNotImplemented, + } + }, + "ok/sshRekey": func(t *testing.T) *authorizeTest { + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + + cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + + tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0]+"#sshpop/sshpop", + []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + + return &authorizeTest{ + auth: a, + token: tok, + ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), + } + }, + "fail/unexpected-method": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + ctx: provisioner.NewContextWithMethod(context.Background(), 15), + err: errors.New("authority.Authorize; method 15 is not supported"), + code: http.StatusInternalServerError, + } + }, + } + + for name, genTestCase := range tests { + t.Run(name, func(t *testing.T) { + tc := genTestCase(t) + got, err := tc.auth.Authorize(tc.ctx, tc.token) + if err != nil { + if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { + assert.Nil(t, got) + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + + ctxErr, ok := err.(*errs.Error) + assert.Fatal(t, ok, "error is not of type *errs.Error") + assert.Equals(t, ctxErr.Details["token"], tc.token) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestAuthority_authorizeRenew(t *testing.T) { + fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") + assert.FatalError(t, err) + + renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") + assert.FatalError(t, err) + + otherCrt, err := pemutil.ReadCertificate("testdata/certs/provisioner-not-found.crt") + assert.FatalError(t, err) type authorizeTest struct { auth *Authority - ott string - err *apiError + cert *x509.Certificate + err error + code int } tests := map[string]func(t *testing.T) *authorizeTest{ - "fail/invalid-ott": func(t *testing.T) *authorizeTest { + "fail/db.IsRevoked-error": func(t *testing.T) *authorizeTest { + a := testAuthority(t) + a.db = &db.MockAuthDB{ + MIsRevoked: func(key string) (bool, error) { + return false, errors.New("force") + }, + } + return &authorizeTest{ auth: a, - ott: "foo", - err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"), - http.StatusUnauthorized, apiCtx{"ott": "foo"}}, + cert: fooCrt, + err: errors.New("authority.authorizeRenew: force"), + code: http.StatusInternalServerError, + } + }, + "fail/revoked": func(t *testing.T) *authorizeTest { + a := testAuthority(t) + a.db = &db.MockAuthDB{ + MIsRevoked: func(key string) (bool, error) { + return true, nil + }, + } + return &authorizeTest{ + auth: a, + cert: fooCrt, + err: errors.New("authority.authorizeRenew: certificate has been revoked"), + code: http.StatusUnauthorized, + } + }, + "fail/load-provisioner": func(t *testing.T) *authorizeTest { + a := testAuthority(t) + a.db = &db.MockAuthDB{ + MIsRevoked: func(key string) (bool, error) { + return false, nil + }, + } + return &authorizeTest{ + auth: a, + cert: otherCrt, + err: errors.New("authority.authorizeRenew: provisioner not found"), + code: http.StatusUnauthorized, + } + }, + "fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest { + a := testAuthority(t) + a.db = &db.MockAuthDB{ + MIsRevoked: func(key string) (bool, error) { + return false, nil + }, + } + + return &authorizeTest{ + auth: a, + cert: renewDisabledCrt, + err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), + code: http.StatusUnauthorized, + } + }, + "ok": func(t *testing.T) *authorizeTest { + a := testAuthority(t) + a.db = &db.MockAuthDB{ + MIsRevoked: func(key string) (bool, error) { + return false, nil + }, + } + return &authorizeTest{ + auth: a, + cert: fooCrt, + } + }, + } + + for name, genTestCase := range tests { + t.Run(name, func(t *testing.T) { + tc := genTestCase(t) + + err := tc.auth.authorizeRenew(tc.cert) + if err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + + ctxErr, ok := err.(*errs.Error) + assert.Fatal(t, ok, "error is not of type *errs.Error") + assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func generateSimpleSSHUserToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { + return generateSSHToken("subject@localhost", iss, aud, time.Now(), &provisioner.SSHOptions{ + CertType: "user", + Principals: []string{"name"}, + }, jwk) +} + +type stepPayload struct { + SSH *provisioner.SSHOptions `json:"ssh,omitempty"` +} + +func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.SSHOptions, jwk *jose.JSONWebKey) (string, error) { + sig, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), + ) + if err != nil { + return "", err + } + + id, err := randutil.ASCII(64) + if err != nil { + return "", err + } + + claims := struct { + jose.Claims + Step *stepPayload `json:"step,omitempty"` + }{ + Claims: jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + }, + Step: &stepPayload{ + SSH: sshOpts, + }, + } + return jose.Signed(sig).Claims(claims).CompactSerialize() +} + +func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) + if err != nil { + return nil, nil, err + } + cert.Key, err = ssh.NewPublicKey(jwk.Public().Key) + if err != nil { + return nil, nil, err + } + if err = cert.SignCert(rand.Reader, signer); err != nil { + return nil, nil, err + } + return cert, jwk, nil +} + +func TestAuthority_authorizeSSHSign(t *testing.T) { + a := testAuthority(t) + + jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) + assert.FatalError(t, err) + + now := time.Now().UTC() + + validIssuer := "step-cli" + validAudience := []string{"https://example.com/ssh/sign"} + + type authorizeTest struct { + auth *Authority + token string + err error + code int + } + tests := map[string]func(t *testing.T) *authorizeTest{ + "fail/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + err: errors.New("authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, } }, "fail/invalid-subject": func(t *testing.T) *authorizeTest { @@ -493,26 +931,18 @@ func TestAuthority_Authorize(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ - auth: a, - ott: raw, - err: &apiError{errors.New("authorizeSign: token subject cannot be empty"), - http.StatusUnauthorized, apiCtx{"ott": raw}}, + auth: a, + token: raw, + err: errors.New("authority.authorizeSSHSign: jwk.AuthorizeSSHSign: jwk.authorizeToken; jwk token subject cannot be empty"), + code: http.StatusUnauthorized, } }, "ok": func(t *testing.T) *authorizeTest { - cl := jwt.Claims{ - Subject: "test.smallstep.com", - Issuer: validIssuer, - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - ID: "44", - } - raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + raw, err := generateSimpleSSHUserToken(validIssuer, validAudience[0], jwk) assert.FatalError(t, err) return &authorizeTest{ - auth: a, - ott: raw, + auth: a, + token: raw, } }, } @@ -520,113 +950,94 @@ func TestAuthority_Authorize(t *testing.T) { for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) - got, err := tc.auth.Authorize(ctx, tc.ott) + + got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - assert.Nil(t, got) - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 8, got) + assert.Len(t, 11, got) } } }) } } -func TestAuthority_authorizeRenewal(t *testing.T) { - fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") +func TestAuthority_authorizeSSHRenew(t *testing.T) { + a := testAuthority(t) + + jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) assert.FatalError(t, err) - renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) assert.FatalError(t, err) - otherCrt, err := pemutil.ReadCertificate("testdata/certs/provisioner-not-found.crt") - assert.FatalError(t, err) + now := time.Now().UTC() + + validIssuer := "step-cli" type authorizeTest struct { - auth *Authority - crt *x509.Certificate - err *apiError + auth *Authority + token string + cert *ssh.Certificate + err error + code int } tests := map[string]func(t *testing.T) *authorizeTest{ - "fail/db.IsRevoked-error": func(t *testing.T) *authorizeTest { - a := testAuthority(t) - a.db = &MockAuthDB{ - isRevoked: func(key string) (bool, error) { - return false, errors.New("force") - }, - } - + "fail/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ - auth: a, - crt: fooCrt, - err: &apiError{errors.New("renew: force"), - http.StatusInternalServerError, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}}, + auth: a, + token: "foo", + err: errors.New("authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, } }, - "fail/revoked": func(t *testing.T) *authorizeTest { - a := testAuthority(t) - a.db = &MockAuthDB{ - isRevoked: func(key string) (bool, error) { - return true, nil - }, + "fail/sshRenew-unimplemented-jwk-provisioner": func(t *testing.T) *authorizeTest { + cl := jwt.Claims{ + Subject: "", + Issuer: validIssuer, + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Minute)), + Audience: testAudiences.SSHRenew, + ID: "43", } + raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + assert.FatalError(t, err) return &authorizeTest{ - auth: a, - crt: fooCrt, - err: &apiError{errors.New("renew: certificate has been revoked"), - http.StatusUnauthorized, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}}, - } - }, - "fail/load-provisioner": func(t *testing.T) *authorizeTest { - a := testAuthority(t) - a.db = &MockAuthDB{ - isRevoked: func(key string) (bool, error) { - return false, nil - }, - } - return &authorizeTest{ - auth: a, - crt: otherCrt, - err: &apiError{errors.New("renew: provisioner not found"), - http.StatusUnauthorized, apiCtx{"serialNumber": "41633491264736369593451462439668497527"}}, - } - }, - "fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest { - a := testAuthority(t) - a.db = &MockAuthDB{ - isRevoked: func(key string) (bool, error) { - return false, nil - }, - } - - return &authorizeTest{ - auth: a, - crt: renewDisabledCrt, - err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), - http.StatusUnauthorized, apiCtx{"serialNumber": "119772236532068856521070735128919532568"}}, + auth: a, + token: raw, + err: errors.New("authority.authorizeSSHRenew: provisioner.AuthorizeSSHRenew not implemented"), + code: http.StatusUnauthorized, } }, "ok": func(t *testing.T) *authorizeTest { - a := testAuthority(t) - a.db = &MockAuthDB{ - isRevoked: func(key string) (bool, error) { - return false, nil - }, - } + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + + cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + + tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", + []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return &authorizeTest{ - auth: a, - crt: fooCrt, + auth: a, + token: tok, + cert: cert, } }, } @@ -635,17 +1046,113 @@ func TestAuthority_authorizeRenewal(t *testing.T) { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - err := tc.auth.authorizeRenew(tc.crt) + got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.cert.Serial, got.Serial) + } + } + }) + } +} + +func TestAuthority_authorizeSSHRevoke(t *testing.T) { + a := testAuthority(t, []Option{WithDatabase(&db.MockAuthDB{ + MIsSSHRevoked: func(serial string) (bool, error) { + return false, nil + }, + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + })}...) + + jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) + assert.FatalError(t, err) + + now := time.Now().UTC() + validIssuer := "step-cli" + + type authorizeTest struct { + auth *Authority + token string + cert *ssh.Certificate + err error + code int + } + tests := map[string]func(t *testing.T) *authorizeTest{ + "fail/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + err: errors.New("authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "fail/invalid-subject": func(t *testing.T) *authorizeTest { + cl := jwt.Claims{ + Subject: "", + Issuer: validIssuer, + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Minute)), + Audience: testAudiences.SSHRevoke, + ID: "43", + } + raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + assert.FatalError(t, err) + return &authorizeTest{ + auth: a, + token: raw, + err: errors.New("authority.authorizeSSHRevoke: jwk.AuthorizeSSHRevoke: jwk.authorizeToken; jwk token subject cannot be empty"), + code: http.StatusUnauthorized, + } + }, + "ok": func(t *testing.T) *authorizeTest { + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + + cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + + tok, err := generateToken(strconv.FormatUint(cert.Serial, 10), p.GetName(), testAudiences.SSHRevoke[0]+"#sshpop/sshpop", + []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + + return &authorizeTest{ + auth: a, + token: tok, + cert: cert, + } + }, + } + + for name, genTestCase := range tests { + t.Run(name, func(t *testing.T) { + tc := genTestCase(t) + + if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) @@ -653,3 +1160,99 @@ func TestAuthority_authorizeRenewal(t *testing.T) { }) } } + +func TestAuthority_authorizeSSHRekey(t *testing.T) { + a := testAuthority(t) + + jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) + assert.FatalError(t, err) + + now := time.Now().UTC() + + validIssuer := "step-cli" + + type authorizeTest struct { + auth *Authority + token string + cert *ssh.Certificate + err error + code int + } + tests := map[string]func(t *testing.T) *authorizeTest{ + "fail/invalid-token": func(t *testing.T) *authorizeTest { + return &authorizeTest{ + auth: a, + token: "foo", + err: errors.New("authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), + code: http.StatusUnauthorized, + } + }, + "fail/sshRekey-unimplemented-jwk-provisioner": func(t *testing.T) *authorizeTest { + cl := jwt.Claims{ + Subject: "", + Issuer: validIssuer, + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Minute)), + Audience: testAudiences.SSHRekey, + ID: "43", + } + raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + assert.FatalError(t, err) + return &authorizeTest{ + auth: a, + token: raw, + err: errors.New("authority.authorizeSSHRekey: provisioner.AuthorizeSSHRekey not implemented"), + code: http.StatusUnauthorized, + } + }, + "ok": func(t *testing.T) *authorizeTest { + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + + cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + + tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0]+"#sshpop/sshpop", + []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + + return &authorizeTest{ + auth: a, + token: tok, + cert: cert, + } + }, + } + + for name, genTestCase := range tests { + t.Run(name, func(t *testing.T) { + tc := genTestCase(t) + + cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token) + if err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.cert.Serial, cert.Serial) + assert.Len(t, 3, signOpts) + } + } + }) + } +} diff --git a/authority/config_test.go b/authority/config_test.go index 40ae639b..c8767dd1 100644 --- a/authority/config_test.go +++ b/authority/config_test.go @@ -1,6 +1,7 @@ package authority import ( + "fmt" "testing" "github.com/pkg/errors" @@ -9,7 +10,6 @@ import ( "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/x509util" stepJOSE "github.com/smallstep/cli/jose" - jose "gopkg.in/square/go-jose.v2" ) func TestConfigValidate(t *testing.T) { @@ -255,28 +255,19 @@ func TestAuthConfigValidate(t *testing.T) { err: errors.New("authority cannot be undefined"), } }, - "fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest { - return AuthConfigValidateTest{ - ac: &AuthConfig{ - Provisioners: provisioner.List{ - &provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, - &provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}}, + /* + "fail-invalid-claims": func(t *testing.T) AuthConfigValidateTest { + return AuthConfigValidateTest{ + ac: &AuthConfig{ + Provisioners: p, + Claims: &provisioner.Claims{ + MinTLSDur: &provisioner.Duration{Duration: -1}, + }, }, - }, - err: errors.New("provisioner type cannot be empty"), - } - }, - "fail-invalid-claims": func(t *testing.T) AuthConfigValidateTest { - return AuthConfigValidateTest{ - ac: &AuthConfig{ - Provisioners: p, - Claims: &provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: -1}, - }, - }, - err: errors.New("claims: MinTLSCertDuration must be greater than 0"), - } - }, + err: errors.New("claims: MinTLSCertDuration must be greater than 0"), + } + }, + */ "ok-empty-provisioners": func(t *testing.T) AuthConfigValidateTest { return AuthConfigValidateTest{ ac: &AuthConfig{}, @@ -311,7 +302,7 @@ func TestAuthConfigValidate(t *testing.T) { assert.Equals(t, tc.err.Error(), err.Error()) } } else { - if assert.Nil(t, tc.err) { + if assert.Nil(t, tc.err, fmt.Sprintf("expected error: %s, but got ", tc.err)) { assert.Equals(t, *tc.ac.Template, tc.asn1dn) } } diff --git a/authority/db_test.go b/authority/db_test.go deleted file mode 100644 index 72684c63..00000000 --- a/authority/db_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package authority - -import ( - "crypto/x509" - - "github.com/smallstep/certificates/db" - "golang.org/x/crypto/ssh" -) - -type MockAuthDB struct { - err error - ret1 interface{} - isRevoked func(string) (bool, error) - isSSHRevoked func(string) (bool, error) - revoke func(rci *db.RevokedCertificateInfo) error - revokeSSH func(rci *db.RevokedCertificateInfo) error - storeCertificate func(crt *x509.Certificate) error - useToken func(id, tok string) (bool, error) - isSSHHost func(principal string) (bool, error) - storeSSHCertificate func(crt *ssh.Certificate) error - getSSHHostPrincipals func() ([]string, error) - shutdown func() error -} - -func (m *MockAuthDB) IsRevoked(sn string) (bool, error) { - if m.isRevoked != nil { - return m.isRevoked(sn) - } - return m.ret1.(bool), m.err -} - -func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) { - if m.isSSHRevoked != nil { - return m.isSSHRevoked(sn) - } - return m.ret1.(bool), m.err -} - -func (m *MockAuthDB) UseToken(id, tok string) (bool, error) { - if m.useToken != nil { - return m.useToken(id, tok) - } - if m.ret1 == nil { - return false, m.err - } - return m.ret1.(bool), m.err -} - -func (m *MockAuthDB) Revoke(rci *db.RevokedCertificateInfo) error { - if m.revoke != nil { - return m.revoke(rci) - } - return m.err -} - -func (m *MockAuthDB) RevokeSSH(rci *db.RevokedCertificateInfo) error { - if m.revokeSSH != nil { - return m.revokeSSH(rci) - } - return m.err -} - -func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error { - if m.storeCertificate != nil { - return m.storeCertificate(crt) - } - return m.err -} - -func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) { - if m.isSSHHost != nil { - return m.isSSHHost(principal) - } - return m.ret1.(bool), m.err -} - -func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error { - if m.storeSSHCertificate != nil { - return m.storeSSHCertificate(crt) - } - return m.err -} - -func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) { - if m.getSSHHostPrincipals != nil { - return m.getSSHHostPrincipals() - } - return m.ret1.([]string), m.err -} - -func (m *MockAuthDB) Shutdown() error { - if m.shutdown != nil { - return m.shutdown() - } - return m.err -} diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index adba8fd3..7adeb311 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" ) // ACME is the acme provisioner type, an entity that can authorize the ACME @@ -79,7 +80,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // certificate was configured to allow renewals. func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + return errs.Unauthorized(errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())) } return nil } diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 2ffdd195..581f20ed 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -3,11 +3,13 @@ package provisioner import ( "context" "crypto/x509" + "net/http" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" ) func TestACME_Getters(t *testing.T) { @@ -88,86 +90,98 @@ func TestACME_Init(t *testing.T) { } } -func TestACME_AuthorizeRevoke(t *testing.T) { - p, err := generateACME() - assert.FatalError(t, err) - assert.Nil(t, p.AuthorizeRevoke(context.TODO(), "")) -} - func TestACME_AuthorizeRenew(t *testing.T) { - p1, err := generateACME() - assert.FatalError(t, err) - p2, err := generateACME() - assert.FatalError(t, err) - - // disable renewal - disable := true - p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) - assert.FatalError(t, err) - - type args struct { + type test struct { + p *ACME cert *x509.Certificate - } - tests := []struct { - name string - prov *ACME - args args err error - }{ - {"ok", p1, args{nil}, nil}, - {"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())}, + code int } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); err != nil { - if assert.NotNil(t, tt.err) { - assert.HasPrefix(t, err.Error(), tt.err.Error()) + tests := map[string]func(*testing.T) test{ + "fail/renew-disabled": func(t *testing.T) test { + p, err := generateACME() + assert.FatalError(t, err) + // disable renewal + disable := true + p.Claims = &Claims{DisableRenewal: &disable} + p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + return test{ + p: p, + cert: &x509.Certificate{}, + code: http.StatusUnauthorized, + err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()), + } + }, + "ok": func(t *testing.T) test { + p, err := generateACME() + assert.FatalError(t, err) + return test{ + p: p, + cert: &x509.Certificate{}, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { - assert.Nil(t, tt.err) + assert.Nil(t, tc.err) } }) } } func TestACME_AuthorizeSign(t *testing.T) { - p1, err := generateACME() - assert.FatalError(t, err) - - tests := []struct { - name string - prov *ACME - method Method - err error - }{ - {"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")}, - {"ok", p1, SignMethod, nil}, + type test struct { + p *ACME + token string + code int + err error } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := NewContextWithMethod(context.Background(), tt.method) - if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil { - if assert.NotNil(t, tt.err) { - assert.HasPrefix(t, err.Error(), tt.err.Error()) + tests := map[string]func(*testing.T) test{ + "ok": func(t *testing.T) test { + p, err := generateACME() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { - if assert.NotNil(t, got) { - assert.Len(t, 4, got) - - for _, o := range got { + if assert.Nil(t, tc.err) && assert.NotNil(t, opts) { + assert.Len(t, 4, opts) + for _, o := range opts { switch v := o.(type) { case *provisionerExtensionOption: assert.Equals(t, v.Type, int(TypeACME)) - assert.Equals(t, v.Name, tt.prov.GetName()) + assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 74fa3a1f..39769118 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -16,6 +16,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" ) @@ -271,7 +272,7 @@ func (p *AWS) Init(config Config) (err error) { func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { payload, err := p.authorizeToken(token) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign") } doc := payload.document @@ -305,7 +306,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // certificate was configured to allow renewals. func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + return errs.Unauthorized(errors.Errorf("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())) } return nil } @@ -349,41 +350,41 @@ func (p *AWS) readURL(url string) ([]byte, error) { func (p *AWS) authorizeToken(token string) (*awsPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, errors.Wrapf(err, "error parsing token") + return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token") } if len(jwt.Headers) == 0 { - return nil, errors.New("error parsing token: header is missing") + return nil, errs.InternalServerError(errors.New("aws.authorizeToken; error parsing token, header is missing")) } var unsafeClaims awsPayload if err := jwt.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil { - return nil, errors.Wrap(err, "error unmarshaling claims") + return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling claims") } var payload awsPayload if err := jwt.Claims(unsafeClaims.Amazon.Signature, &payload); err != nil { - return nil, errors.Wrap(err, "error verifying claims") + return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error verifying claims") } // Validate identity document signature if err := p.checkSignature(payload.Amazon.Document, payload.Amazon.Signature); err != nil { - return nil, err + return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token signature") } var doc awsInstanceIdentityDocument if err := json.Unmarshal(payload.Amazon.Document, &doc); err != nil { - return nil, errors.Wrap(err, "error unmarshaling identity document") + return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling aws identity document") } switch { case doc.AccountID == "": - return nil, errors.New("identity document accountId cannot be empty") + return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document accountId cannot be empty")) case doc.InstanceID == "": - return nil, errors.New("identity document instanceId cannot be empty") + return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty")) case doc.PrivateIP == "": - return nil, errors.New("identity document privateIp cannot be empty") + return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty")) case doc.Region == "": - return nil, errors.New("identity document region cannot be empty") + return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document region cannot be empty")) } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -393,12 +394,12 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { Issuer: awsIssuer, Time: now, }, time.Minute); err != nil { - return nil, errors.Wrapf(err, "invalid token") + return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token") } // validate audiences with the defaults if !matchesAudience(payload.Audience, p.audiences.Sign) { - return nil, errors.New("invalid token: invalid audience claim (aud)") + return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)")) } // Validate subject, it has to be known if disableCustomSANs is enabled @@ -406,7 +407,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { if payload.Subject != doc.InstanceID && payload.Subject != doc.PrivateIP && payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) { - return nil, errors.New("invalid token: invalid subject claim (sub)") + return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)")) } } @@ -420,14 +421,14 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { } } if !found { - return nil, errors.New("invalid identity document: accountId is not valid") + return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid")) } } // validate instance age if d := p.InstanceAge.Value(); d > 0 { if now.Sub(doc.PendingTime) > d { - return nil, errors.New("identity document pendingTime is too old") + return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document pendingTime is too old")) } } @@ -438,18 +439,18 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + return nil, errs.Unauthorized(errors.Errorf("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())) } claims, err := p.authorizeToken(token) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSSHSign") } doc := claims.document signOptions := []SignOption{ // set the key id to the token subject - sshCertificateKeyIDModifier(claims.Subject), + sshCertKeyIDModifier(claims.Subject), } // Default to host + known IPs/hostnames @@ -463,7 +464,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user options signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) // Set defaults if not given as user options - signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) + signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) return append(signOptions, // Set the default extensions. diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index e855bf9f..8c59bebe 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -10,12 +10,15 @@ import ( "encoding/hex" "encoding/pem" "fmt" + "net/http" "net/url" "strings" "testing" "time" + "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" ) @@ -229,6 +232,213 @@ func TestAWS_Init(t *testing.T) { } } +func TestAWS_authorizeToken(t *testing.T) { + block, _ := pem.Decode([]byte(awsTestKey)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + t.Fatal("error decoding AWS key") + } + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + assert.FatalError(t, err) + badKey, err := rsa.GenerateKey(rand.Reader, 1024) + assert.FatalError(t, err) + + type test struct { + p *AWS + token string + err error + code int + } + tests := map[string]func(*testing.T) test{ + "fail/bad-token": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; error parsing aws token"), + } + }, + "fail/cannot-validate-sig": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), badKey) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; invalid aws token signature"), + } + }, + "fail/empty-account-id": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", awsIssuer, p.GetID(), "", "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; aws identity document accountId cannot be empty"), + } + }, + "fail/empty-instance-id": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", awsIssuer, p.GetID(), p.Accounts[0], "", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty"), + } + }, + "fail/empty-private-ip": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id", + "", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty"), + } + }, + "fail/empty-region": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id", + "127.0.0.1", "", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; aws identity document region cannot be empty"), + } + }, + "fail/invalid-token-issuer": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", "bad-issuer", p.GetID(), p.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; invalid aws token"), + } + }, + "fail/invalid-audience": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", awsIssuer, "bad-audience", p.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)"), + } + }, + "fail/invalid-subject-disabled-custom-SANs": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + p.DisableCustomSANs = true + tok, err := generateAWSToken( + "foo", awsIssuer, p.GetID(), p.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)"), + } + }, + "fail/invalid-account-id": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", awsIssuer, p.GetID(), "foo", "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid"), + } + }, + "fail/instance-age": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + p.InstanceAge = Duration{1 * time.Minute} + tok, err := generateAWSToken( + "instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("aws.authorizeToken; aws identity document pendingTime is too old"), + } + }, + "ok": func(t *testing.T) test { + p, err := generateAWS() + assert.FatalError(t, err) + tok, err := generateAWSToken( + "instance-id", awsIssuer, p.GetID(), p.Accounts[0], "instance-id", + "127.0.0.1", "us-west-1", time.Now(), key) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if claims, err := tc.p.authorizeToken(tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) && assert.NotNil(t, claims) { + assert.Equals(t, claims.Subject, "instance-id") + assert.Equals(t, claims.Issuer, awsIssuer) + assert.NotNil(t, claims.Amazon) + + aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID()) + assert.FatalError(t, err) + assert.Equals(t, claims.Audience[0], aud) + } + } + }) + } +} + func TestAWS_AuthorizeSign(t *testing.T) { p1, srv, err := generateAWSWithServer() assert.FatalError(t, err) @@ -326,26 +536,27 @@ func TestAWS_AuthorizeSign(t *testing.T) { aws *AWS args args wantLen int + code int wantErr bool }{ - {"ok", p1, args{t1}, 5, false}, - {"ok", p2, args{t2}, 7, false}, - {"ok", p2, args{t2Hostname}, 7, false}, - {"ok", p2, args{t2PrivateIP}, 7, false}, - {"ok", p1, args{t4}, 5, false}, - {"fail account", p3, args{t3}, 0, true}, - {"fail token", p1, args{"token"}, 0, true}, - {"fail subject", p1, args{failSubject}, 0, true}, - {"fail issuer", p1, args{failIssuer}, 0, true}, - {"fail audience", p1, args{failAudience}, 0, true}, - {"fail account", p1, args{failAccount}, 0, true}, - {"fail instanceID", p1, args{failInstanceID}, 0, true}, - {"fail privateIP", p1, args{failPrivateIP}, 0, true}, - {"fail region", p1, args{failRegion}, 0, true}, - {"fail exp", p1, args{failExp}, 0, true}, - {"fail nbf", p1, args{failNbf}, 0, true}, - {"fail key", p1, args{failKey}, 0, true}, - {"fail instance age", p2, args{failInstanceAge}, 0, true}, + {"ok", p1, args{t1}, 5, http.StatusOK, false}, + {"ok", p2, args{t2}, 7, http.StatusOK, false}, + {"ok", p2, args{t2Hostname}, 7, http.StatusOK, false}, + {"ok", p2, args{t2PrivateIP}, 7, http.StatusOK, false}, + {"ok", p1, args{t4}, 5, http.StatusOK, false}, + {"fail account", p3, args{t3}, 0, http.StatusUnauthorized, true}, + {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, + {"fail subject", p1, args{failSubject}, 0, http.StatusUnauthorized, true}, + {"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true}, + {"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true}, + {"fail account", p1, args{failAccount}, 0, http.StatusUnauthorized, true}, + {"fail instanceID", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true}, + {"fail privateIP", p1, args{failPrivateIP}, 0, http.StatusUnauthorized, true}, + {"fail region", p1, args{failRegion}, 0, http.StatusUnauthorized, true}, + {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true}, + {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true}, + {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, + {"fail instance age", p2, args{failInstanceAge}, 0, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -354,8 +565,13 @@ func TestAWS_AuthorizeSign(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) + } else { + assert.Len(t, tt.wantLen, got) } - assert.Len(t, tt.wantLen, got) }) } } @@ -368,6 +584,14 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, err) defer srv.Close() + p2, err := generateAWS() + assert.FatalError(t, err) + // disable sshCA + disable := false + p2.Claims = &Claims{EnableSSHCA: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com") assert.FatalError(t, err) @@ -407,30 +631,35 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { aws *AWS args args expected *SSHOptions + code int wantErr bool wantSignErr bool }{ - {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false}, - {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false}, - {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false}, - {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false}, - {"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, false, false}, - {"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, false, false}, - {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false}, - {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true}, - {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true}, - {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true}, - {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, false, true}, + {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, http.StatusOK, false, false}, + {"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, http.StatusOK, false, false}, + {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true}, + {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true}, + {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, + {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, + {"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, + {"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := NewContextWithMethod(context.Background(), SignSSHMethod) - got, err := tt.aws.AuthorizeSSHSign(ctx, tt.args.token) + got, err := tt.aws.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) @@ -447,6 +676,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { }) } } + func TestAWS_AuthorizeRenew(t *testing.T) { p1, err := generateAWS() assert.FatalError(t, err) @@ -466,44 +696,20 @@ func TestAWS_AuthorizeRenew(t *testing.T) { name string aws *AWS args args + code int wantErr bool }{ - {"ok", p1, args{nil}, false}, - {"fail", p2, args{nil}, true}, + {"ok", p1, args{nil}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.aws.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestAWS_AuthorizeRevoke(t *testing.T) { - p1, srv, err := generateAWSWithServer() - assert.FatalError(t, err) - defer srv.Close() - - t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com") - assert.FatalError(t, err) - - type args struct { - token string - } - tests := []struct { - name string - aws *AWS - args args - wantErr bool - }{ - {"ok", p1, args{t1}, true}, // revoke is disabled - {"fail", p1, args{"token"}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.aws.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr { - t.Errorf("AWS.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) } }) } diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 998ef6e1..86eb516f 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -13,6 +13,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" ) @@ -209,14 +210,14 @@ func (p *Azure) Init(config Config) (err error) { return nil } -// parseToken returns the claims, name, group, error. -func (p *Azure) parseToken(token string) (*azurePayload, string, string, error) { +// authorizeToken returs the claims, name, group, error. +func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, "", "", errors.Wrapf(err, "error parsing token") + return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token") } if len(jwt.Headers) == 0 { - return nil, "", "", errors.New("error parsing token: header is missing") + return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token missing header")) } var found bool @@ -229,7 +230,7 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error) } } if !found { - return nil, "", "", errors.New("cannot validate token") + return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; cannot validate azure token")) } if err := claims.ValidateWithLeeway(jose.Expected{ @@ -237,17 +238,17 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error) Issuer: p.oidcConfig.Issuer, Time: time.Now(), }, 1*time.Minute); err != nil { - return nil, "", "", errors.Wrap(err, "failed to validate payload") + return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload") } // Validate TenantID if claims.TenantID != p.TenantID { - return nil, "", "", errors.New("validation failed: invalid tenant id claim (tid)") + return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")) } re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) if len(re) != 4 { - return nil, "", "", errors.Errorf("error parsing xms_mirid claim: %s", claims.XMSMirID) + return nil, "", "", errs.Unauthorized(errors.Errorf("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)) } group, name := re[2], re[3] return &claims, name, group, nil @@ -256,9 +257,9 @@ func (p *Azure) parseToken(token string) (*azurePayload, string, string, error) // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - _, name, group, err := p.parseToken(token) + _, name, group, err := p.authorizeToken(token) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign") } // Filter by resource group @@ -271,7 +272,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } } if !found { - return nil, errors.New("validation failed: invalid resource group") + return nil, errs.Unauthorized(errors.New("azure.AuthorizeSign; azure token validation failed - invalid resource group")) } } @@ -301,7 +302,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // certificate was configured to allow renewals. func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + return errs.Unauthorized(errors.Errorf("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())) } return nil } @@ -309,16 +310,16 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + return nil, errs.Unauthorized(errors.Errorf("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())) } - _, name, _, err := p.parseToken(token) + _, name, _, err := p.authorizeToken(token) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign") } signOptions := []SignOption{ // set the key id to the token subject - sshCertificateKeyIDModifier(name), + sshCertKeyIDModifier(name), } // Default to host + known hostnames @@ -329,7 +330,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Validate user options signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) // Set defaults if not given as user options - signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) + signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) return append(signOptions, // Set the default extensions. diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 1760ed5c..13e6ac8e 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -15,7 +15,10 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/cli/jose" ) func TestAzure_Getters(t *testing.T) { @@ -209,6 +212,148 @@ func TestAzure_Init(t *testing.T) { } } +func TestAzure_authorizeToken(t *testing.T) { + type test struct { + p *Azure + token string + err error + code int + } + tests := map[string]func(*testing.T) test{ + "fail/bad-token": func(t *testing.T) test { + p, err := generateAzure() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("azure.authorizeToken; error parsing azure token"), + } + }, + "fail/cannot-validate-sig": func(t *testing.T) test { + p, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("azure.authorizeToken; cannot validate azure token"), + } + }, + "fail/invalid-token-issuer": func(t *testing.T) test { + p, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("azure.authorizeToken; failed to validate azure token payload"), + } + }, + "fail/invalid-tenant-id": func(t *testing.T) test { + p, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, + "foo", "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)"), + } + }, + "fail/invalid-xms-mir-id": func(t *testing.T) test { + p, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + jwk := &p.keyStore.keySet.Keys[0] + sig, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), + ) + assert.FatalError(t, err) + + now := time.Now() + claims := azurePayload{ + Claims: jose.Claims{ + Subject: "subject", + Issuer: p.oidcConfig.Issuer, + IssuedAt: jose.NewNumericDate(now), + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + Audience: []string{azureDefaultAudience}, + ID: "the-jti", + }, + AppID: "the-appid", + AppIDAcr: "the-appidacr", + IdentityProvider: "the-idp", + ObjectID: "the-oid", + TenantID: p.TenantID, + Version: "the-version", + XMSMirID: "foo", + } + tok, err := jose.Signed(sig).Claims(claims).CompactSerialize() + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("azure.authorizeToken; error parsing xms_mirid claim - foo"), + } + }, + "ok": func(t *testing.T) test { + p, srv, err := generateAzureWithServer() + assert.FatalError(t, err) + defer srv.Close() + tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if claims, name, group, err := tc.p.authorizeToken(tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, claims.Subject, "subject") + assert.Equals(t, claims.Issuer, tc.p.oidcConfig.Issuer) + assert.Equals(t, claims.Audience[0], azureDefaultAudience) + + assert.Equals(t, name, "virtualMachine") + assert.Equals(t, group, "resourceGroup") + } + } + }) + } +} + func TestAzure_AuthorizeSign(t *testing.T) { p1, srv, err := generateAzureWithServer() assert.FatalError(t, err) @@ -283,19 +428,20 @@ func TestAzure_AuthorizeSign(t *testing.T) { azure *Azure args args wantLen int + code int wantErr bool }{ - {"ok", p1, args{t1}, 4, false}, - {"ok", p2, args{t2}, 6, false}, - {"ok", p1, args{t11}, 4, false}, - {"fail tenant", p3, args{t3}, 0, true}, - {"fail resource group", p4, args{t4}, 0, true}, - {"fail token", p1, args{"token"}, 0, true}, - {"fail issuer", p1, args{failIssuer}, 0, true}, - {"fail audience", p1, args{failAudience}, 0, true}, - {"fail exp", p1, args{failExp}, 0, true}, - {"fail nbf", p1, args{failNbf}, 0, true}, - {"fail key", p1, args{failKey}, 0, true}, + {"ok", p1, args{t1}, 4, http.StatusOK, false}, + {"ok", p2, args{t2}, 6, http.StatusOK, false}, + {"ok", p1, args{t11}, 4, http.StatusOK, false}, + {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, + {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, + {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, + {"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true}, + {"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true}, + {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true}, + {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true}, + {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -304,8 +450,51 @@ func TestAzure_AuthorizeSign(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) + } else { + assert.Len(t, tt.wantLen, got) + } + }) + } +} + +func TestAzure_AuthorizeRenew(t *testing.T) { + p1, err := generateAzure() + assert.FatalError(t, err) + p2, err := generateAzure() + assert.FatalError(t, err) + + // disable renewal + disable := true + p2.Claims = &Claims{DisableRenewal: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + azure *Azure + args args + code int + wantErr bool + }{ + {"ok", p1, args{nil}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) } - assert.Len(t, tt.wantLen, got) }) } } @@ -318,6 +507,14 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, err) defer srv.Close() + p2, err := generateAzure() + assert.FatalError(t, err) + // disable sshCA + disable := false + p2.Claims = &Claims{EnableSSHCA: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + t1, err := p1.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) @@ -349,28 +546,33 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { azure *Azure args args expected *SSHOptions + code int wantErr bool wantSignErr bool }{ - {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false}, - {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false}, - {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false}, - {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false}, - {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false}, - {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true}, - {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true}, - {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true}, - {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, false, true}, + {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true}, + {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true}, + {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, + {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, + {"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, + {"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := NewContextWithMethod(context.Background(), SignSSHMethod) - got, err := tt.azure.AuthorizeSSHSign(ctx, tt.args.token) + got, err := tt.azure.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) @@ -388,68 +590,6 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { } } -func TestAzure_AuthorizeRenew(t *testing.T) { - p1, err := generateAzure() - assert.FatalError(t, err) - p2, err := generateAzure() - assert.FatalError(t, err) - - // disable renewal - disable := true - p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) - assert.FatalError(t, err) - - type args struct { - cert *x509.Certificate - } - tests := []struct { - name string - azure *Azure - args args - wantErr bool - }{ - {"ok", p1, args{nil}, false}, - {"fail", p2, args{nil}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { - t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestAzure_AuthorizeRevoke(t *testing.T) { - az, srv, err := generateAzureWithServer() - assert.FatalError(t, err) - defer srv.Close() - - token, err := az.GetIdentityToken("subject", "caURL") - assert.FatalError(t, err) - - type args struct { - token string - } - tests := []struct { - name string - azure *Azure - args args - wantErr bool - }{ - {"ok token", az, args{token}, true}, // revoke is disabled - {"bad token", az, args{"bad token"}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.azure.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr { - t.Errorf("Azure.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - func TestAzure_assertConfig(t *testing.T) { p1, err := generateAzure() assert.FatalError(t, err) diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index bf189ee5..a1d11740 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -78,7 +78,7 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) // match with server audiences if matchesAudience(claims.Audience, audiences) { - // Use fragment to get provisioner name (GCP, AWS) + // Use fragment to get provisioner name (GCP, AWS, SSHPOP) if fragment != "" { return c.Load(fragment) } diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index bc531e92..69a3006a 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -14,6 +14,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" ) @@ -210,7 +211,7 @@ func (p *GCP) Init(config Config) error { func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign") } ce := claims.Google.ComputeEngine @@ -239,10 +240,10 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er ), nil } -// AuthorizeRenewal returns an error if the renewal is disabled. -func (p *GCP) AuthorizeRenewal(ctx context.Context, cert *x509.Certificate) error { +// AuthorizeRenew returns an error if the renewal is disabled. +func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + return errs.Unauthorized(errors.Errorf("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())) } return nil } @@ -260,10 +261,10 @@ func (p *GCP) assertConfig() { func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, errors.Wrapf(err, "error parsing token") + return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token") } if len(jwt.Headers) == 0 { - return nil, errors.New("error parsing token: header is missing") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; error parsing gcp token - header is missing")) } var found bool @@ -277,7 +278,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } } if !found { - return nil, errors.Errorf("failed to validate payload: cannot find key for kid %s", kid) + return nil, errs.Unauthorized(errors.Errorf("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)) } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -287,12 +288,12 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { Issuer: "https://accounts.google.com", Time: now, }, time.Minute); err != nil { - return nil, errors.Wrapf(err, "invalid token") + return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; invalid gcp token payload") } // validate audiences with the defaults if !matchesAudience(claims.Audience, p.audiences.Sign) { - return nil, errors.New("invalid token: invalid audience claim (aud)") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")) } // validate subject (service account) @@ -305,7 +306,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } } if !found { - return nil, errors.New("invalid token: invalid subject claim") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim")) } } @@ -319,26 +320,26 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } } if !found { - return nil, errors.New("invalid token: invalid project id") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid project id")) } } // validate instance age if d := p.InstanceAge.Value(); d > 0 { if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d { - return nil, errors.New("token google.compute_engine.instance_creation_timestamp is too old") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")) } } switch { case claims.Google.ComputeEngine.InstanceID == "": - return nil, errors.New("token google.compute_engine.instance_id cannot be empty") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")) case claims.Google.ComputeEngine.InstanceName == "": - return nil, errors.New("token google.compute_engine.instance_name cannot be empty") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")) case claims.Google.ComputeEngine.ProjectID == "": - return nil, errors.New("token google.compute_engine.project_id cannot be empty") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")) case claims.Google.ComputeEngine.Zone == "": - return nil, errors.New("token google.compute_engine.zone cannot be empty") + return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")) } return &claims, nil @@ -347,18 +348,18 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + return nil, errs.Unauthorized(errors.Errorf("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())) } claims, err := p.authorizeToken(token) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSSHSign") } ce := claims.Google.ComputeEngine signOptions := []SignOption{ // set the key id to the token subject - sshCertificateKeyIDModifier(ce.InstanceName), + sshCertKeyIDModifier(ce.InstanceName), } // Default to host + known hostnames @@ -372,7 +373,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user options signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) // Set defaults if not given as user options - signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) + signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) return append(signOptions, // Set the default extensions diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 4764dfc7..bdda8fd9 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -16,7 +16,10 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/cli/jose" ) func TestGCP_Getters(t *testing.T) { @@ -211,6 +214,202 @@ func TestGCP_Init(t *testing.T) { } } +func TestGCP_authorizeToken(t *testing.T) { + type test struct { + p *GCP + token string + err error + code int + } + tests := map[string]func(*testing.T) test{ + "fail/bad-token": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; error parsing gcp token"), + } + }, + "fail/cannot-validate-sig": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://accounts.google.com", p.GetID(), + "instance-id", "instance-name", "project-id", "zone", + time.Now(), jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid "), + } + }, + "fail/invalid-issuer": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://foo.bar.zap", p.GetID(), + "instance-id", "instance-name", "project-id", "zone", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; invalid gcp token payload"), + } + }, + "fail/invalid-serviceAccount": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + tok, err := generateGCPToken("foo", + "https://accounts.google.com", p.GetID(), + "instance-id", "instance-name", "project-id", "zone", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim"), + } + }, + "fail/invalid-projectID": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + p.ProjectIDs = []string{"foo", "bar"} + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://accounts.google.com", p.GetID(), + "instance-id", "instance-name", "project-id", "zone", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; invalid gcp token - invalid project id"), + } + }, + "fail/instance-age": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + p.InstanceAge = Duration{1 * time.Minute} + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://accounts.google.com", p.GetID(), + "instance-id", "instance-name", "project-id", "zone", + time.Now().Add(-1*time.Minute), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old"), + } + }, + "fail/empty-instance-id": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://accounts.google.com", p.GetID(), + "", "instance-name", "project-id", "zone", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty"), + } + }, + "fail/empty-instance-name": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://accounts.google.com", p.GetID(), + "instance-id", "", "project-id", "zone", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty"), + } + }, + "fail/empty-project-id": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://accounts.google.com", p.GetID(), + "instance-id", "instance-name", "", "zone", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty"), + } + }, + "fail/empty-zone": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://accounts.google.com", p.GetID(), + "instance-id", "instance-name", "project-id", "", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty"), + } + }, + "ok": func(t *testing.T) test { + p, err := generateGCP() + assert.FatalError(t, err) + tok, err := generateGCPToken(p.ServiceAccounts[0], + "https://accounts.google.com", p.GetID(), + "instance-id", "instance-name", "project-id", "zone", + time.Now(), &p.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if claims, err := tc.p.authorizeToken(tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) && assert.NotNil(t, claims) { + assert.Equals(t, claims.Subject, tc.p.ServiceAccounts[0]) + assert.Equals(t, claims.Issuer, "https://accounts.google.com") + assert.NotNil(t, claims.Google) + + aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID()) + assert.FatalError(t, err) + assert.Equals(t, claims.Audience[0], aud) + } + } + }) + } +} + func TestGCP_AuthorizeSign(t *testing.T) { p1, err := generateGCP() assert.FatalError(t, err) @@ -313,24 +512,25 @@ func TestGCP_AuthorizeSign(t *testing.T) { gcp *GCP args args wantLen int + code int wantErr bool }{ - {"ok", p1, args{t1}, 4, false}, - {"ok", p2, args{t2}, 6, false}, - {"ok", p3, args{t3}, 4, false}, - {"fail token", p1, args{"token"}, 0, true}, - {"fail key", p1, args{failKey}, 0, true}, - {"fail iss", p1, args{failIss}, 0, true}, - {"fail aud", p1, args{failAud}, 0, true}, - {"fail exp", p1, args{failExp}, 0, true}, - {"fail nbf", p1, args{failNbf}, 0, true}, - {"fail service account", p1, args{failServiceAccount}, 0, true}, - {"fail invalid project id", p3, args{failInvalidProjectID}, 0, true}, - {"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, true}, - {"fail instance id", p1, args{failInstanceID}, 0, true}, - {"fail instance name", p1, args{failInstanceName}, 0, true}, - {"fail project id", p1, args{failProjectID}, 0, true}, - {"fail zone", p1, args{failZone}, 0, true}, + {"ok", p1, args{t1}, 4, http.StatusOK, false}, + {"ok", p2, args{t2}, 6, http.StatusOK, false}, + {"ok", p3, args{t3}, 4, http.StatusOK, false}, + {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, + {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, + {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true}, + {"fail aud", p1, args{failAud}, 0, http.StatusUnauthorized, true}, + {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true}, + {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true}, + {"fail service account", p1, args{failServiceAccount}, 0, http.StatusUnauthorized, true}, + {"fail invalid project id", p3, args{failInvalidProjectID}, 0, http.StatusUnauthorized, true}, + {"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, http.StatusUnauthorized, true}, + {"fail instance id", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true}, + {"fail instance name", p1, args{failInstanceName}, 0, http.StatusUnauthorized, true}, + {"fail project id", p1, args{failProjectID}, 0, http.StatusUnauthorized, true}, + {"fail zone", p1, args{failZone}, 0, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -339,8 +539,13 @@ func TestGCP_AuthorizeSign(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) + } else { + assert.Len(t, tt.wantLen, got) } - assert.Len(t, tt.wantLen, got) }) } } @@ -352,6 +557,14 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { p1, err := generateGCP() assert.FatalError(t, err) + p2, err := generateGCP() + assert.FatalError(t, err) + // disable sshCA + disable := false + p2.Claims = &Claims{EnableSSHCA: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + t1, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "project-id", "zone", @@ -394,30 +607,35 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { gcp *GCP args args expected *SSHOptions + code int wantErr bool wantSignErr bool }{ - {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false}, - {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false}, - {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false}, - {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false}, - {"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, false, false}, - {"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, false, false}, - {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false}, - {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true}, - {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true}, - {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true}, - {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, false, true}, + {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, http.StatusOK, false, false}, + {"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, http.StatusOK, false, false}, + {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true}, + {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true}, + {"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, + {"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, + {"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, + {"fail-invalid-token", p1, args{"foo", SSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := NewContextWithMethod(context.Background(), SignSSHMethod) - got, err := tt.gcp.AuthorizeSSHSign(ctx, tt.args.token) + got, err := tt.gcp.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) @@ -435,7 +653,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { } } -func TestGCP_AuthorizeRenewal(t *testing.T) { +func TestGCP_AuthorizeRenew(t *testing.T) { p1, err := generateGCP() assert.FatalError(t, err) p2, err := generateGCP() @@ -454,46 +672,20 @@ func TestGCP_AuthorizeRenewal(t *testing.T) { name string prov *GCP args args + code int wantErr bool }{ - {"ok", p1, args{nil}, false}, - {"fail", p2, args{nil}, true}, + {"ok", p1, args{nil}, http.StatusOK, false}, + {"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.prov.AuthorizeRenewal(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { - t.Errorf("GCP.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestGCP_AuthorizeRevoke(t *testing.T) { - p1, err := generateGCP() - assert.FatalError(t, err) - - t1, err := generateGCPToken(p1.ServiceAccounts[0], - "https://accounts.google.com", p1.GetID(), - "instance-id", "instance-name", "project-id", "zone", - time.Now(), &p1.keyStore.keySet.Keys[0]) - assert.FatalError(t, err) - - type args struct { - token string - } - tests := []struct { - name string - gcp *GCP - args args - wantErr bool - }{ - {"ok", p1, args{t1}, true}, // revoke is disabled - {"fail", p1, args{"token"}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.gcp.AuthorizeRevoke(context.TODO(), tt.args.token); (err != nil) != tt.wantErr { - t.Errorf("GCP.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) + if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) } }) } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index b5add3f4..1c613de6 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -3,9 +3,11 @@ package provisioner import ( "context" "crypto/x509" + "net/http" "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/jose" ) @@ -99,12 +101,12 @@ func (p *JWK) Init(config Config) (err error) { func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, errors.Wrapf(err, "error parsing token") + return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk token") } var claims jwtPayload if err = jwt.Claims(p.Key, &claims); err != nil { - return nil, errors.Wrap(err, "error parsing claims") + return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk claims") } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -113,17 +115,17 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err Issuer: p.Name, Time: time.Now().UTC(), }, time.Minute); err != nil { - return nil, errors.Wrapf(err, "invalid token") + return nil, errs.Wrapf(http.StatusUnauthorized, err, "jwk.authorizeToken; invalid jwk claims") } // validate audiences with the defaults if !matchesAudience(claims.Audience, audiences) { - return nil, errors.Errorf("invalid token: invalid audience claim (aud); want %s, but got %s", - audiences, claims.Audience) + return nil, errs.Unauthorized(errors.Errorf("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s", + audiences, claims.Audience)) } if claims.Subject == "" { - return nil, errors.New("token subject cannot be empty") + return nil, errs.Unauthorized(errors.New("jwk.authorizeToken; jwk token subject cannot be empty")) } return &claims, nil @@ -133,14 +135,14 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err // revoke the certificate with serial number in the `sub` property. func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { _, err := p.authorizeToken(token, p.audiences.Revoke) - return err + return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.audiences.Sign) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } // NOTE: This is for backwards compatibility with older versions of cli @@ -171,7 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + return errs.Unauthorized(errors.Errorf("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())) } return nil } @@ -179,14 +181,14 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + return nil, errs.Unauthorized(errors.Errorf("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())) } claims, err := p.authorizeToken(token, p.audiences.SSHSign) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") } if claims.Step == nil || claims.Step.SSH == nil { - return nil, errors.New("authorization token must be an SSH provisioning token") + return nil, errs.Unauthorized(errors.New("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")) } opts := claims.Step.SSH @@ -205,19 +207,19 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals)) } if !opts.ValidAfter.IsZero() { - signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) + signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) } if !opts.ValidBefore.IsZero() { - signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) + signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) } if opts.KeyID != "" { - signOptions = append(signOptions, sshCertificateKeyIDModifier(opts.KeyID)) + signOptions = append(signOptions, sshCertKeyIDModifier(opts.KeyID)) } else { - signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject)) + signOptions = append(signOptions, sshCertKeyIDModifier(claims.Subject)) } // Default to a user certificate with no principals if not set - signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert}) + signOptions = append(signOptions, sshCertDefaultsModifier{CertType: SSHUserCert}) return append(signOptions, // Set the default extensions. @@ -238,5 +240,5 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { _, err := p.authorizeToken(token, p.audiences.SSHRevoke) - return err + return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 47a6e7cc..a0c48ee9 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -7,12 +7,14 @@ import ( "crypto/rsa" "crypto/x509" "net" + "net/http" "strings" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" ) @@ -162,25 +164,29 @@ func TestJWK_authorizeToken(t *testing.T) { name string prov *JWK args args + code int err error }{ - {"fail-token", p1, args{failTok}, errors.New("error parsing token")}, - {"fail-key", p1, args{failKey}, errors.New("error parsing claims")}, - {"fail-claims", p1, args{failClaims}, errors.New("error parsing claims")}, - {"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")}, - {"fail-issuer", p1, args{failIss}, errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")}, - {"fail-expired", p1, args{failExp}, errors.New("invalid token: square/go-jose/jwt: validation failed, token is expired (exp)")}, - {"fail-not-before", p1, args{failNbf}, errors.New("invalid token: square/go-jose/jwt: validation failed, token not valid yet (nbf)")}, - {"fail-audience", p1, args{failAud}, errors.New("invalid token: invalid audience claim (aud)")}, - {"fail-subject", p1, args{failSub}, errors.New("token subject cannot be empty")}, - {"ok", p1, args{t1}, nil}, - {"ok-no-encrypted-key", p2, args{t2}, nil}, - {"ok-no-sans", p1, args{t3}, nil}, + {"fail-token", p1, args{failTok}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk token")}, + {"fail-key", p1, args{failKey}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")}, + {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")}, + {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")}, + {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)")}, + {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, token is expired (exp)")}, + {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: square/go-jose/jwt: validation failed, token not valid yet (nbf)")}, + {"fail-audience", p1, args{failAud}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk token audience claim (aud)")}, + {"fail-subject", p1, args{failSub}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; jwk token subject cannot be empty")}, + {"ok", p1, args{t1}, http.StatusOK, nil}, + {"ok-no-encrypted-key", p2, args{t2}, http.StatusOK, nil}, + {"ok-no-sans", p1, args{t3}, http.StatusOK, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tt.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { @@ -208,15 +214,19 @@ func TestJWK_AuthorizeRevoke(t *testing.T) { name string prov *JWK args args + code int err error }{ - {"fail-signature", p1, args{failSig}, errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")}, - {"ok", p1, args{t1}, nil}, + {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.AuthorizeRevoke: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")}, + {"ok", p1, args{t1}, http.StatusOK, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token); err != nil { if assert.NotNil(t, tt.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } } @@ -246,20 +256,24 @@ func TestJWK_AuthorizeSign(t *testing.T) { name string prov *JWK args args + code int err error dns []string emails []string ips []net.IP }{ - {name: "fail-signature", prov: p1, args: args{failSig}, err: errors.New("error parsing claims: square/go-jose: error in cryptographic primitive")}, - {"ok-sans", p1, args{t1}, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}}, - {"ok-no-sans", p1, args{t2}, nil, []string{"subject"}, []string{}, []net.IP{}}, + {name: "fail-signature", prov: p1, args: args{failSig}, code: http.StatusUnauthorized, err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")}, + {"ok-sans", p1, args{t1}, http.StatusOK, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}}, + {"ok-no-sans", p1, args{t2}, http.StatusOK, nil, []string{"subject"}, []string{}, []net.IP{}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if assert.NotNil(t, tt.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { @@ -315,15 +329,20 @@ func TestJWK_AuthorizeRenew(t *testing.T) { name string prov *JWK args args + code int wantErr bool }{ - {"ok", p1, args{nil}, false}, - {"fail", p2, args{nil}, true}, + {"ok", p1, args{nil}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) } }) } @@ -335,6 +354,14 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { p1, err := generateJWK() assert.FatalError(t, err) + p2, err := generateJWK() + assert.FatalError(t, err) + // disable sshCA + disable := false + p2.Claims = &Claims{EnableSSHCA: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + jwk, err := decryptJSONWebKey(p1.EncryptedKey) assert.FatalError(t, err) @@ -382,30 +409,34 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { prov *JWK args args expected *SSHOptions + code int wantErr bool wantSignErr bool }{ - {"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false}, - {"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false}, - {"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false}, - {"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false}, - {"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false}, - {"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, false, false}, - {"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false}, - {"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false}, - {"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false}, - {"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, true, false}, - {"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true}, + {"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false}, + {"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false}, + {"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false}, + {"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false}, + {"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false}, + {"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"fail-sshCA-disabled", p2, args{"foo", SSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false}, + {"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, + {"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := NewContextWithMethod(context.Background(), SignSSHMethod) - got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token) + got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) @@ -511,10 +542,9 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := NewContextWithMethod(context.Background(), SignSSHMethod) token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk) assert.FatalError(t, err) - if got, err := tt.prov.AuthorizeSSHSign(ctx, token); (err != nil) != tt.wantErr { + if got, err := tt.prov.AuthorizeSSHSign(context.Background(), token); (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) } else if !tt.wantErr && assert.NotNil(t, got) { var opts SSHOptions @@ -535,3 +565,52 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) { }) } } + +func TestJWK_AuthorizeSSHRevoke(t *testing.T) { + type test struct { + p *JWK + token string + code int + err error + } + tests := map[string]func(*testing.T) test{ + "fail/invalid-token": func(t *testing.T) test { + p, err := generateJWK() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("jwk.AuthorizeSSHRevoke: jwk.authorizeToken; error parsing jwk token"), + } + }, + "ok": func(t *testing.T) test { + p, err := generateJWK() + assert.FatalError(t, err) + jwk, err := decryptJSONWebKey(p.EncryptedKey) + assert.FatalError(t, err) + + tok, err := generateToken("subject", p.Name, testAudiences.SSHRevoke[0], "name@smallstep.com", []string{"127.0.0.1", "max@smallstep.com", "foo"}, time.Now(), jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index e7d45236..0826028e 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -6,8 +6,10 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "net/http" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/jose" "golang.org/x/crypto/ed25519" @@ -138,7 +140,8 @@ func (p *K8sSA) Init(config Config) (err error) { func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, errors.Wrapf(err, "error parsing token") + return nil, errs.Wrap(http.StatusUnauthorized, err, + "k8ssa.authorizeToken; error parsing k8sSA token") } var ( @@ -146,7 +149,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, claims k8sSAPayload ) if p.pubKeys == nil { - return nil, errors.New("TokenReview API integration not implemented") + return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")) /* NOTE: We plan to support the TokenReview API in a future release. Below is some code that should be useful when we prioritize this integration. @@ -174,7 +177,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, } } if !valid { - return nil, errors.New("error validating token and extracting claims") + return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")) } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -182,11 +185,11 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, if err = claims.Validate(jose.Expected{ Issuer: k8sSAIssuer, }); err != nil { - return nil, errors.Wrapf(err, "invalid token claims") + return nil, errs.Wrap(http.StatusUnauthorized, err, "k8ssa.authorizeToken; invalid k8sSA token claims") } if claims.Subject == "" { - return nil, errors.New("token subject cannot be empty") + return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA token subject cannot be empty")) } return &claims, nil @@ -196,14 +199,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, // revoke the certificate with serial number in the `sub` property. func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { _, err := p.authorizeToken(token, p.audiences.Revoke) - return err + return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - _, err := p.authorizeToken(token, p.audiences.Sign) - if err != nil { - return nil, err + if _, err := p.authorizeToken(token, p.audiences.Sign); err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") } return []SignOption{ @@ -219,7 +221,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // AuthorizeRenew returns an error if the renewal is disabled. func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + return errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())) } return nil } @@ -227,17 +229,14 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro // AuthorizeSSHSign validates an request for an SSH certificate. func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errors.Errorf("authorizeSSHSign: ssh ca is disabled for provisioner %s", p.GetID()) + return nil, errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())) } - _, err := p.authorizeToken(token, p.audiences.SSHSign) - if err != nil { - return nil, errors.Wrap(err, "authorizeSSHSign") + if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") } // Default to a user certificate with no principals if not set - signOptions := []SignOption{ - sshCertificateDefaultsModifier{CertType: SSHUserCert}, - } + signOptions := []SignOption{sshCertDefaultsModifier{CertType: SSHUserCert}} return append(signOptions, // Set the default extensions. diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 692e7bab..09a856c5 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -3,11 +3,13 @@ package provisioner import ( "context" "crypto/x509" + "net/http" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" ) @@ -36,6 +38,7 @@ func TestK8sSA_authorizeToken(t *testing.T) { p *K8sSA token string err error + code int } tests := map[string]func(*testing.T) test{ "fail/bad-token": func(t *testing.T) test { @@ -44,7 +47,24 @@ func TestK8sSA_authorizeToken(t *testing.T) { return test{ p: p, token: "foo", - err: errors.New("error parsing token"), + code: http.StatusUnauthorized, + err: errors.New("k8ssa.authorizeToken; error parsing k8sSA token"), + } + }, + "fail/not-implemented": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + p, err := generateK8sSA(nil) + assert.FatalError(t, err) + tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk) + p.pubKeys = nil + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + err: errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented"), + code: http.StatusUnauthorized, } }, "fail/error-validating-token": func(t *testing.T) test { @@ -58,7 +78,8 @@ func TestK8sSA_authorizeToken(t *testing.T) { return test{ p: p, token: tok, - err: errors.New("error validating token and extracting claims"), + err: errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims"), + code: http.StatusUnauthorized, } }, "fail/invalid-issuer": func(t *testing.T) test { @@ -73,7 +94,8 @@ func TestK8sSA_authorizeToken(t *testing.T) { return test{ p: p, token: tok, - err: errors.New("invalid token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"), + code: http.StatusUnauthorized, + err: errors.New("k8ssa.authorizeToken; invalid k8sSA token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"), } }, "ok": func(t *testing.T) test { @@ -94,6 +116,9 @@ func TestK8sSA_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { @@ -105,12 +130,12 @@ func TestK8sSA_authorizeToken(t *testing.T) { } } -func TestK8sSA_AuthorizeSign(t *testing.T) { +func TestK8sSA_AuthorizeRevoke(t *testing.T) { type test struct { p *K8sSA token string - ctx context.Context err error + code int } tests := map[string]func(*testing.T) test{ "fail/invalid-token": func(t *testing.T) test { @@ -119,21 +144,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { return test{ p: p, token: "foo", - err: errors.New("error parsing token"), - } - }, - "fail/ssh-unimplemented": func(t *testing.T) test { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - p, err := generateK8sSA(jwk.Public().Key) - assert.FatalError(t, err) - tok, err := generateK8sSAToken(jwk, nil) - assert.FatalError(t, err) - return test{ - p: p, - ctx: NewContextWithMethod(context.Background(), SignSSHMethod), - token: tok, - err: errors.Errorf("ssh certificates not enabled for k8s ServiceAccount provisioners"), + code: http.StatusUnauthorized, + err: errors.New("k8ssa.AuthorizeRevoke: k8ssa.authorizeToken; error parsing k8sSA token"), } }, "ok": func(t *testing.T) test { @@ -145,7 +157,6 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { assert.FatalError(t, err) return test{ p: p, - ctx: NewContextWithMethod(context.Background(), SignMethod), token: tok, } }, @@ -153,10 +164,110 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token); err != nil { + if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestK8sSA_AuthorizeRenew(t *testing.T) { + type test struct { + p *K8sSA + cert *x509.Certificate + err error + code int + } + tests := map[string]func(*testing.T) test{ + "fail/renew-disabled": func(t *testing.T) test { + p, err := generateK8sSA(nil) + assert.FatalError(t, err) + // disable renewal + disable := true + p.Claims = &Claims{DisableRenewal: &disable} + p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + return test{ + p: p, + cert: &x509.Certificate{}, + code: http.StatusUnauthorized, + err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()), + } + }, + "ok": func(t *testing.T) test { + p, err := generateK8sSA(nil) + assert.FatalError(t, err) + return test{ + p: p, + cert: &x509.Certificate{}, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestK8sSA_AuthorizeSign(t *testing.T) { + type test struct { + p *K8sSA + token string + code int + err error + } + tests := map[string]func(*testing.T) test{ + "fail/invalid-token": func(t *testing.T) test { + p, err := generateK8sSA(nil) + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("k8ssa.AuthorizeSign: k8ssa.authorizeToken; error parsing k8sSA token"), + } + }, + "ok": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + p, err := generateK8sSA(jwk.Public().Key) + assert.FatalError(t, err) + tok, err := generateK8sSAToken(jwk, nil) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { @@ -187,20 +298,37 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { } } -func TestK8sSA_AuthorizeRevoke(t *testing.T) { +func TestK8sSA_AuthorizeSSHSign(t *testing.T) { type test struct { p *K8sSA token string + code int err error } tests := map[string]func(*testing.T) test{ + "fail/sshCA-disabled": func(t *testing.T) test { + p, err := generateK8sSA(nil) + assert.FatalError(t, err) + // disable sshCA + disable := false + p.Claims = &Claims{EnableSSHCA: &disable} + p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()), + } + }, "fail/invalid-token": func(t *testing.T) test { p, err := generateK8sSA(nil) assert.FatalError(t, err) return test{ p: p, token: "foo", - err: errors.New("error parsing token"), + code: http.StatusUnauthorized, + err: errors.New("k8ssa.AuthorizeSSHSign: k8ssa.authorizeToken; error parsing k8sSA token"), } }, "ok": func(t *testing.T) test { @@ -219,45 +347,36 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil { + if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestK8sSA_AuthorizeRenew(t *testing.T) { - p1, err := generateK8sSA(nil) - assert.FatalError(t, err) - p2, err := generateK8sSA(nil) - assert.FatalError(t, err) - - // disable renewal - disable := true - p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) - assert.FatalError(t, err) - - type args struct { - cert *x509.Certificate - } - tests := []struct { - name string - prov *K8sSA - args args - wantErr bool - }{ - {"ok", p1, args{nil}, false}, - {"fail", p2, args{nil}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { - t.Errorf("X5C.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + if assert.Nil(t, tc.err) { + if assert.NotNil(t, opts) { + tot := 0 + for _, o := range opts { + switch v := o.(type) { + case sshCertDefaultsModifier: + assert.Equals(t, v.CertType, SSHUserCert) + case *sshDefaultExtensionModifier: + case *sshCertificateValidityValidator: + assert.Equals(t, v.Claimer, tc.p.claimer) + case *sshDefaultPublicKeyValidator: + case *sshCertificateDefaultValidator: + case *sshDefaultDuration: + assert.Equals(t, v.Claimer, tc.p.claimer) + default: + assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + } + tot++ + } + assert.Equals(t, tot, 6) + } + } } }) } diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go index 4e5f32a7..775ed96f 100644 --- a/authority/provisioner/method.go +++ b/authority/provisioner/method.go @@ -16,14 +16,16 @@ const ( SignMethod Method = iota // RevokeMethod is the method used to revoke X.509 certificates. RevokeMethod - // SignSSHMethod is the method used to sign SSH certificates. - SignSSHMethod - // RenewSSHMethod is the method used to renew SSH certificates. - RenewSSHMethod - // RevokeSSHMethod is the method used to revoke SSH certificates. - RevokeSSHMethod - // RekeySSHMethod is the method used to rekey SSH certificates. - RekeySSHMethod + // RenewMethod is the method used to renew X.509 certificates. + RenewMethod + // SSHSignMethod is the method used to sign SSH certificates. + SSHSignMethod + // SSHRenewMethod is the method used to renew SSH certificates. + SSHRenewMethod + // SSHRevokeMethod is the method used to revoke SSH certificates. + SSHRevokeMethod + // SSHRekeyMethod is the method used to rekey SSH certificates. + SSHRekeyMethod ) // String returns a string representation of the context method. @@ -33,14 +35,16 @@ func (m Method) String() string { return "sign-method" case RevokeMethod: return "revoke-method" - case SignSSHMethod: - return "sign-ssh-method" - case RenewSSHMethod: - return "renew-ssh-method" - case RevokeSSHMethod: - return "revoke-ssh-method" - case RekeySSHMethod: - return "rekey-ssh-method" + case RenewMethod: + return "renew-method" + case SSHSignMethod: + return "ssh-sign-method" + case SSHRenewMethod: + return "ssh-renew-method" + case SSHRevokeMethod: + return "ssh-revoke-method" + case SSHRekeyMethod: + return "ssh-rekey-method" default: return "unknown" } diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 4c4b68d2..87710ebb 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -12,6 +12,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" ) @@ -189,17 +190,17 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error { Audience: jose.Audience{o.ClientID}, Time: time.Now().UTC(), }, time.Minute); err != nil { - return errors.Wrap(err, "failed to validate payload") + return errs.Wrap(http.StatusUnauthorized, err, "validatePayload: failed to validate oidc token payload") } // Validate azp if present if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID { - return errors.New("failed to validate payload: invalid azp") + return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: invalid azp")) } // Enforce an email claim if p.Email == "" { - return errors.New("failed to validate payload: email not found") + return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email not found")) } // Validate domains (case-insensitive) @@ -213,7 +214,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error { } } if !found { - return errors.New("failed to validate payload: email is not allowed") + return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email is not allowed")) } } @@ -229,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error { } } if !found { - return errors.New("validation failed: invalid group") + return errs.Unauthorized(errors.New("validatePayload: oidc token payload validation failed: invalid group")) } } @@ -241,13 +242,15 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error { func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, errors.Wrapf(err, "error parsing token") + return nil, errs.Wrap(http.StatusUnauthorized, err, + "oidc.AuthorizeToken; error parsing oidc token") } // Parse claims to get the kid var claims openIDPayload if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, errors.Wrap(err, "error parsing claims") + return nil, errs.Wrap(http.StatusUnauthorized, err, + "oidc.AuthorizeToken; error parsing oidc token claims") } found := false @@ -260,11 +263,11 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) { } } if !found { - return nil, errors.New("cannot validate token") + return nil, errs.Unauthorized(errors.New("oidc.AuthorizeToken; cannot validate oidc token")) } if err := o.ValidatePayload(claims); err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeToken") } return &claims, nil @@ -276,21 +279,21 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) { func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error { claims, err := o.authorizeToken(token) if err != nil { - return err + return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeRevoke") } // Only admins can revoke certificates. if o.IsAdmin(claims.Email) { return nil } - return errors.New("cannot revoke with non-admin token") + return errs.Unauthorized(errors.New("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")) } // AuthorizeSign validates the given token. func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := o.authorizeToken(token) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSign") } so := []SignOption{ @@ -315,7 +318,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // certificate was configured to allow renewals. func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if o.claimer.IsDisableRenewal() { - return errors.Errorf("renew is disabled for provisioner %s", o.GetID()) + return errs.Unauthorized(errors.Errorf("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())) } return nil } @@ -323,22 +326,22 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !o.claimer.IsSSHCAEnabled() { - return nil, errors.Errorf("ssh ca is disabled for provisioner %s", o.GetID()) + return nil, errs.Unauthorized(errors.Errorf("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())) } claims, err := o.authorizeToken(token) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } signOptions := []SignOption{ // set the key id to the token email - sshCertificateKeyIDModifier(claims.Email), + sshCertKeyIDModifier(claims.Email), } // Get the identity using either the default identityFunc or one injected // externally. iden, err := o.getIdentityFunc(o, claims.Email) if err != nil { - return nil, errors.Wrap(err, "authorizeSSHSign") + return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } defaults := SSHOptions{ CertType: SSHUserCert, @@ -354,7 +357,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Default to a user certificate with usernames as principals if those options // are not set. - signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) + signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) return append(signOptions, // Set the default extensions @@ -374,14 +377,14 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error { claims, err := o.authorizeToken(token) if err != nil { - return err + return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHRevoke") } // Only admins can revoke certificates. - if o.IsAdmin(claims.Email) { - return nil + if !o.IsAdmin(claims.Email) { + return errs.Unauthorized(errors.New("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")) } - return errors.New("cannot revoke with non-admin token") + return nil } func getAndDecode(uri string, v interface{}) error { diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index cbb7b2a2..d0782c1e 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -7,12 +7,14 @@ import ( "crypto/rsa" "crypto/x509" "fmt" + "net/http" "strings" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" ) @@ -206,20 +208,21 @@ func TestOIDC_authorizeToken(t *testing.T) { name string prov *OIDC args args + code int wantErr bool }{ - {"ok1", p1, args{t1}, false}, - {"ok2", p2, args{t2}, false}, - {"fail-email", p3, args{failEmail}, true}, - {"fail-domain", p3, args{failDomain}, true}, - {"fail-key", p1, args{failKey}, true}, - {"fail-token", p1, args{failTok}, true}, - {"fail-claims", p1, args{failClaims}, true}, - {"fail-issuer", p1, args{failIss}, true}, - {"fail-audience", p1, args{failAud}, true}, - {"fail-signature", p1, args{failSig}, true}, - {"fail-expired", p1, args{failExp}, true}, - {"fail-not-before", p1, args{failNbf}, true}, + {"ok1", p1, args{t1}, http.StatusOK, false}, + {"ok2", p2, args{t2}, http.StatusOK, false}, + {"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true}, + {"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, true}, + {"fail-key", p1, args{failKey}, http.StatusUnauthorized, true}, + {"fail-token", p1, args{failTok}, http.StatusUnauthorized, true}, + {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, true}, + {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, true}, + {"fail-audience", p1, args{failAud}, http.StatusUnauthorized, true}, + {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, true}, + {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, true}, + {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -230,6 +233,9 @@ func TestOIDC_authorizeToken(t *testing.T) { return } if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else { assert.NotNil(t, got) @@ -282,21 +288,24 @@ func TestOIDC_AuthorizeSign(t *testing.T) { name string prov *OIDC args args + code int wantErr bool }{ - {"ok1", p1, args{t1}, false}, - {"admin", p3, args{okAdmin}, false}, - {"fail-email", p3, args{failEmail}, true}, + {"ok1", p1, args{t1}, http.StatusOK, false}, + {"admin", p3, args{okAdmin}, http.StatusOK, false}, + {"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := NewContextWithMethod(context.Background(), SignMethod) - got, err := tt.prov.AuthorizeSign(ctx, tt.args.token) + got, err := tt.prov.AuthorizeSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else { if assert.NotNil(t, got) { @@ -330,6 +339,107 @@ func TestOIDC_AuthorizeSign(t *testing.T) { } } +func TestOIDC_AuthorizeRevoke(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + + var keys jose.JSONWebKeySet + assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys)) + + // Create test provisioners + p1, err := generateOIDC() + assert.FatalError(t, err) + p3, err := generateOIDC() + assert.FatalError(t, err) + // Admin + Domains + p3.Admins = []string{"name@smallstep.com", "root@example.com"} + p3.Domains = []string{"smallstep.com"} + + // Update configuration endpoints and initialize + config := Config{Claims: globalProvisionerClaims} + p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + assert.FatalError(t, p1.Init(config)) + assert.FatalError(t, p3.Init(config)) + + t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) + assert.FatalError(t, err) + // Admin email not in domains + okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) + assert.FatalError(t, err) + // Invalid email + failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) + assert.FatalError(t, err) + + type args struct { + token string + } + tests := []struct { + name string + prov *OIDC + args args + code int + wantErr bool + }{ + {"ok1", p1, args{t1}, http.StatusUnauthorized, true}, + {"admin", p3, args{okAdmin}, http.StatusOK, false}, + {"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token) + if (err != nil) != tt.wantErr { + fmt.Println(tt) + t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) + } + }) + } +} + +func TestOIDC_AuthorizeRenew(t *testing.T) { + p1, err := generateOIDC() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + + // disable renewal + disable := true + p2.Claims = &Claims{DisableRenewal: &disable} + p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + prov *OIDC + args args + code int + wantErr bool + }{ + {"ok", p1, args{nil}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert) + if (err != nil) != tt.wantErr { + t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) + } + }) + } +} + func TestOIDC_AuthorizeSSHSign(t *testing.T) { tm, fn := mockNow() defer fn() @@ -351,9 +461,16 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, err) p5, err := generateOIDC() assert.FatalError(t, err) + p6, err := generateOIDC() + assert.FatalError(t, err) // Admin + Domains p3.Admins = []string{"name@smallstep.com", "root@example.com"} p3.Domains = []string{"smallstep.com"} + // disable sshCA + disable := false + p6.Claims = &Claims{EnableSSHCA: &disable} + p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) + assert.FatalError(t, err) // Update configuration endpoints and initialize config := Config{Claims: globalProvisionerClaims} @@ -425,48 +542,53 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { prov *OIDC args args expected *SSHOptions + code int wantErr bool wantSignErr bool }{ - {"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false}, - {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false}, - {"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false}, + {"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false}, + {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false}, + {"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, &SSHOptions{CertType: "user", Principals: []string{"name"}, - ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"ok-principals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{Principals: []string{"mariano"}}, pub}, &SSHOptions{CertType: "user", Principals: []string{"mariano"}, - ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"ok-emptyPrincipals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{}, pub}, &SSHOptions{CertType: "user", Principals: []string{"max", "mariano"}, - ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, &SSHOptions{CertType: "user", Principals: []string{"name"}, - ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, - {"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, false, false}, - {"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, false, false}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, + {"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, http.StatusOK, false, false}, + {"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, http.StatusOK, false, false}, {"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub}, &SSHOptions{CertType: "user", Principals: []string{"root"}, - ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, &SSHOptions{CertType: "user", Principals: []string{"name"}, - ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, - {"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false}, - {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true}, - {"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, false, true}, - {"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, false, true}, - {"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, true, false}, - {"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, true, false}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, + {"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, + expectedHostOptions, http.StatusOK, false, false}, + {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, + {"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, http.StatusOK, false, true}, + {"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, http.StatusOK, false, true}, + {"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, + {"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, http.StatusInternalServerError, true, false}, + {"fail-sshCA-disabled", p6, args{"foo", SSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := NewContextWithMethod(context.Background(), SignSSHMethod) - got, err := tt.prov.AuthorizeSSHSign(ctx, tt.args.token) + got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) @@ -484,36 +606,32 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { } } -func TestOIDC_AuthorizeRevoke(t *testing.T) { +func TestOIDC_AuthorizeSSHRevoke(t *testing.T) { + p1, err := generateOIDC() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + p2.Admins = []string{"root@example.com"} + srv := generateJWKServer(2) defer srv.Close() - var keys jose.JSONWebKeySet assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys)) - // Create test provisioners - p1, err := generateOIDC() - assert.FatalError(t, err) - p3, err := generateOIDC() - assert.FatalError(t, err) - // Admin + Domains - p3.Admins = []string{"name@smallstep.com", "root@example.com"} - p3.Domains = []string{"smallstep.com"} - - // Update configuration endpoints and initialize config := Config{Claims: globalProvisionerClaims} p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" - p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" assert.FatalError(t, p1.Init(config)) - assert.FatalError(t, p3.Init(config)) + assert.FatalError(t, p2.Init(config)) - t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) + // Invalid email + failEmail, err := generateToken("subject", "the-issuer", p1.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) // Admin email not in domains - okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) + noAdmin, err := generateToken("subject", "the-issuer", p1.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) - // Invalid email - failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) + // Admin email in domains + okAdmin, err := generateToken("subject", "the-issuer", p2.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) type args struct { @@ -523,52 +641,22 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { name string prov *OIDC args args + code int wantErr bool }{ - {"ok1", p1, args{t1}, true}, - {"admin", p3, args{okAdmin}, false}, - {"fail-email", p3, args{failEmail}, true}, + {"ok", p2, args{okAdmin}, http.StatusOK, false}, + {"fail/invalid-token", p1, args{failEmail}, http.StatusUnauthorized, true}, + {"fail/not-admin", p1, args{noAdmin}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token) + err := tt.prov.AuthorizeSSHRevoke(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { - fmt.Println(tt) - t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -func TestOIDC_AuthorizeRenew(t *testing.T) { - p1, err := generateOIDC() - assert.FatalError(t, err) - p2, err := generateOIDC() - assert.FatalError(t, err) - - // disable renewal - disable := true - p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) - assert.FatalError(t, err) - - type args struct { - cert *x509.Certificate - } - tests := []struct { - name string - prov *OIDC - args args - wantErr bool - }{ - {"ok", p1, args{nil}, false}, - {"fail", p2, args{nil}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { - t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) + } else if err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.code) } }) } diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 4b4200f5..40e1e309 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" ) @@ -283,43 +284,43 @@ type base struct{} // AuthorizeSign returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for signing x509 Certificates. func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - return nil, errors.New("not implemented; provisioner does not implement AuthorizeSign") + return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSign not implemented")) } // AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for revoking x509 Certificates. func (b *base) AuthorizeRevoke(ctx context.Context, token string) error { - return errors.New("not implemented; provisioner does not implement AuthorizeRevoke") + return errs.Unauthorized(errors.New("provisioner.AuthorizeRevoke not implemented")) } // AuthorizeRenew returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for renewing x509 Certificates. func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - return errors.New("not implemented; provisioner does not implement AuthorizeRenew") + return errs.Unauthorized(errors.New("provisioner.AuthorizeRenew not implemented")) } // AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for signing SSH Certificates. func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - return nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHSign") + return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHSign not implemented")) } // AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for revoking SSH Certificates. func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error { - return errors.New("not implemented; provisioner does not implement AuthorizeSSHRevoke") + return errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRevoke not implemented")) } // AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for renewing SSH Certificates. func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { - return nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRenew") + return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRenew not implemented")) } // AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for rekeying SSH Certificates. func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { - return nil, nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRekey") + return nil, nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRekey not implemented")) } // Identity is the type representing an externally supplied identity that is used diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index 14e62769..2577c62f 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -1,10 +1,14 @@ package provisioner import ( + "context" + "net/http" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" ) func TestType_String(t *testing.T) { @@ -101,3 +105,93 @@ func TestDefaultIdentityFunc(t *testing.T) { }) } } + +func TestUnimplementedMethods(t *testing.T) { + tests := []struct { + name string + p Interface + method Method + }{ + {"jwk/sshRekey", &JWK{}, SSHRekeyMethod}, + {"jwk/sshRenew", &JWK{}, SSHRenewMethod}, + {"aws/revoke", &AWS{}, RevokeMethod}, + {"aws/sshRenew", &AWS{}, SSHRenewMethod}, + {"aws/rekey", &AWS{}, SSHRekeyMethod}, + {"aws/sshRevoke", &AWS{}, SSHRevokeMethod}, + {"azure/revoke", &Azure{}, RevokeMethod}, + {"azure/sshRenew", &Azure{}, SSHRenewMethod}, + {"azure/sshRekey", &Azure{}, SSHRekeyMethod}, + {"azure/sshRevoke", &Azure{}, SSHRevokeMethod}, + {"gcp/revoke", &GCP{}, RevokeMethod}, + {"gcp/sshRenew", &GCP{}, SSHRenewMethod}, + {"gcp/sshRekey", &GCP{}, SSHRekeyMethod}, + {"gcp/sshRevoke", &GCP{}, SSHRevokeMethod}, + {"oidc/sshRenew", &OIDC{}, SSHRenewMethod}, + {"oidc/sshRekey", &OIDC{}, SSHRekeyMethod}, + {"x5c/sshRenew", &X5C{}, SSHRenewMethod}, + {"x5c/sshRekey", &X5C{}, SSHRekeyMethod}, + {"x5c/sshRevoke", &X5C{}, SSHRekeyMethod}, + {"acme/revoke", &ACME{}, RevokeMethod}, + {"acme/sshSign", &ACME{}, SSHSignMethod}, + {"acme/sshRekey", &ACME{}, SSHRekeyMethod}, + {"acme/sshRenew", &ACME{}, SSHRenewMethod}, + {"acme/sshRevoke", &ACME{}, SSHRevokeMethod}, + {"sshpop/sign", &SSHPOP{}, SignMethod}, + {"sshpop/renew", &SSHPOP{}, RenewMethod}, + {"sshpop/revoke", &SSHPOP{}, RevokeMethod}, + {"sshpop/sshSign", &SSHPOP{}, SSHSignMethod}, + {"k8ssa/sshRekey", &K8sSA{}, SSHRekeyMethod}, + {"k8ssa/sshRenew", &K8sSA{}, SSHRenewMethod}, + {"k8ssa/sshRevoke", &K8sSA{}, SSHRevokeMethod}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ( + err error + msg string + ) + + switch tt.method { + case SignMethod: + var signOpts []SignOption + signOpts, err = tt.p.AuthorizeSign(context.Background(), "") + assert.Nil(t, signOpts) + msg = "provisioner.AuthorizeSign not implemented" + case RenewMethod: + err = tt.p.AuthorizeRenew(context.Background(), nil) + msg = "provisioner.AuthorizeRenew not implemented" + case RevokeMethod: + err = tt.p.AuthorizeRevoke(context.Background(), "") + msg = "provisioner.AuthorizeRevoke not implemented" + case SSHSignMethod: + var signOpts []SignOption + signOpts, err = tt.p.AuthorizeSSHSign(context.Background(), "") + assert.Nil(t, signOpts) + msg = "provisioner.AuthorizeSSHSign not implemented" + case SSHRenewMethod: + var cert *ssh.Certificate + cert, err = tt.p.AuthorizeSSHRenew(context.Background(), "") + assert.Nil(t, cert) + msg = "provisioner.AuthorizeSSHRenew not implemented" + case SSHRekeyMethod: + var ( + cert *ssh.Certificate + signOpts []SignOption + ) + cert, signOpts, err = tt.p.AuthorizeSSHRekey(context.Background(), "") + assert.Nil(t, cert) + assert.Nil(t, signOpts) + msg = "provisioner.AuthorizeSSHRekey not implemented" + case SSHRevokeMethod: + err = tt.p.AuthorizeSSHRevoke(context.Background(), "") + msg = "provisioner.AuthorizeSSHRevoke not implemented" + default: + t.Errorf("unexpected method %s", tt.method) + } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized) + assert.Equals(t, err.Error(), msg) + }) + } +} diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 1e6547b7..ed049b6c 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -30,7 +30,7 @@ type SignOption interface{} // CertificateValidator is the interface used to validate a X.509 certificate. type CertificateValidator interface { SignOption - Valid(crt *x509.Certificate) error + Valid(cert *x509.Certificate, o Options) error } // CertificateRequestValidator is the interface used to validate a X.509 @@ -106,7 +106,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error { return errors.New("certificate request cannot contain an empty common name") } if req.Subject.CommonName != string(v) { - return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v) + return errors.Errorf("certificate request does not contain the valid common name; requested common name = %s, token subject = %s", req.Subject.CommonName, v) } return nil } @@ -265,33 +265,30 @@ func newValidityValidator(min, max time.Duration) *validityValidator { // Valid validates the certificate validity settings (notBefore/notAfter) and // and total duration. -func (v *validityValidator) Valid(crt *x509.Certificate) error { +func (v *validityValidator) Valid(cert *x509.Certificate, o Options) error { var ( - na = crt.NotAfter.Truncate(time.Second) - nb = crt.NotBefore.Truncate(time.Second) + na = cert.NotAfter.Truncate(time.Second) + nb = cert.NotBefore.Truncate(time.Second) now = time.Now().Truncate(time.Second) ) - // To not take into account the backdate, time.Now() will be used to - // calculate the duration if NotBefore is in the past. - var d time.Duration - if now.After(nb) { - d = na.Sub(now) - } else { - d = na.Sub(nb) - } + d := na.Sub(nb) if na.Before(now) { - return errors.Errorf("NotAfter: %v cannot be in the past", na) + return errors.Errorf("notAfter cannot be in the past; na=%v", na) } if na.Before(nb) { - return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb) + return errors.Errorf("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) } if d < v.min { return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) } - if d > v.max { + // NOTE: this check is not "technically correct". We're allowing the max + // duration of a cert to be "max + backdate" and not all certificates will + // be backdated (e.g. if a user passes the NotBefore value then we do not + // apply a backdate). This is good enough. + if d > v.max+o.Backdate { return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max) } diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index 1076d3b5..74c8d1f4 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -3,9 +3,10 @@ package provisioner import ( "crypto/x509" "crypto/x509/pkix" + "fmt" "net" "net/url" - "reflect" + "strings" "testing" "time" @@ -48,22 +49,22 @@ func Test_emailOnlyIdentity_Valid(t *testing.T) { } func Test_defaultPublicKeyValidator_Valid(t *testing.T) { - _shortRSA, err := pemutil.Read("./testdata/short-rsa.csr") + _shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr") assert.FatalError(t, err) shortRSA, ok := _shortRSA.(*x509.CertificateRequest) assert.Fatal(t, ok) - _rsa, err := pemutil.Read("./testdata/rsa.csr") + _rsa, err := pemutil.Read("./testdata/certs/rsa.csr") assert.FatalError(t, err) rsaCSR, ok := _rsa.(*x509.CertificateRequest) assert.Fatal(t, ok) - _ecdsa, err := pemutil.Read("./testdata/ecdsa.csr") + _ecdsa, err := pemutil.Read("./testdata/certs/ecdsa.csr") assert.FatalError(t, err) ecdsaCSR, ok := _ecdsa.(*x509.CertificateRequest) assert.Fatal(t, ok) - _ed25519, err := pemutil.Read("./testdata/ed25519.csr") + _ed25519, err := pemutil.Read("./testdata/certs/ed25519.csr") assert.FatalError(t, err) ed25519CSR, ok := _ed25519.(*x509.CertificateRequest) assert.Fatal(t, ok) @@ -246,30 +247,191 @@ func Test_ipAddressesValidator_Valid(t *testing.T) { } func Test_validityValidator_Valid(t *testing.T) { - type fields struct { - min time.Duration - max time.Duration + type test struct { + cert *x509.Certificate + opts Options + vv *validityValidator + err error } - type args struct { - crt *x509.Certificate - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - v := &validityValidator{ - min: tt.fields.min, - max: tt.fields.max, + tests := map[string]func() test{ + "fail/notAfter-past": func() test { + return test{ + vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, + cert: &x509.Certificate{NotAfter: time.Now().Add(-5 * time.Minute)}, + opts: Options{}, + err: errors.New("notAfter cannot be in the past"), } - if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr { - t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) + }, + "fail/notBefore-after-notAfter": func() test { + return test{ + vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, + cert: &x509.Certificate{NotBefore: time.Now().Add(10 * time.Minute), + NotAfter: time.Now().Add(5 * time.Minute)}, + opts: Options{}, + err: errors.New("notAfter cannot be before notBefore"), } + }, + "fail/duration-too-short": func() test { + n := now() + return test{ + vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, + cert: &x509.Certificate{NotBefore: n, + NotAfter: n.Add(3 * time.Minute)}, + opts: Options{}, + err: errors.New("is less than the authorized minimum certificate duration of "), + } + }, + "ok/duration-exactly-min": func() test { + n := now() + return test{ + vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, + cert: &x509.Certificate{NotBefore: n, + NotAfter: n.Add(5 * time.Minute)}, + opts: Options{}, + } + }, + "fail/duration-too-great": func() test { + n := now() + return test{ + vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, + cert: &x509.Certificate{NotBefore: n, + NotAfter: n.Add(24*time.Hour + time.Second)}, + err: errors.New("is more than the authorized maximum certificate duration of "), + } + }, + "ok/duration-exactly-max": func() test { + n := time.Now() + return test{ + vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, + cert: &x509.Certificate{NotBefore: n, + NotAfter: n.Add(24 * time.Hour)}, + } + }, + "ok/duration-exact-min-with-backdate": func() test { + now := time.Now() + cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(5 * time.Minute)} + time.Sleep(time.Second) + return test{ + vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, + cert: cert, + opts: Options{Backdate: time.Second}, + } + }, + "ok/duration-exact-max-with-backdate": func() test { + backdate := time.Second + now := time.Now() + cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(24*time.Hour + backdate)} + time.Sleep(backdate) + return test{ + vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, + cert: cert, + opts: Options{Backdate: backdate}, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tt := run() + if err := tt.vv.Valid(tt.cert, tt.opts); err != nil { + if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) { + assert.True(t, strings.Contains(err.Error(), tt.err.Error()), + fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error())) + } + } else { + assert.Nil(t, tt.err, fmt.Sprintf("expected err = %s, but not ", tt.err)) + } + }) + } +} + +func Test_profileDefaultDuration_Option(t *testing.T) { + type test struct { + so Options + pdd profileDefaultDuration + cert *x509.Certificate + valid func(*x509.Certificate) + } + tests := map[string]func() test{ + "ok/notBefore-notAfter-duration-empty": func() test { + return test{ + pdd: profileDefaultDuration(0), + so: Options{}, + cert: new(x509.Certificate), + valid: func(cert *x509.Certificate) { + n := now() + assert.True(t, n.After(cert.NotBefore)) + assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore)) + + assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter)) + assert.True(t, n.Add(24*time.Hour).Add(-1*time.Minute).Before(cert.NotAfter)) + }, + } + }, + "ok/notBefore-set": func() test { + nb := time.Now().Add(5 * time.Minute).UTC() + return test{ + pdd: profileDefaultDuration(0), + so: Options{NotBefore: NewTimeDuration(nb)}, + cert: new(x509.Certificate), + valid: func(cert *x509.Certificate) { + assert.Equals(t, cert.NotBefore, nb) + assert.Equals(t, cert.NotAfter, nb.Add(24*time.Hour)) + }, + } + }, + "ok/duration-set": func() test { + d := 4 * time.Hour + return test{ + pdd: profileDefaultDuration(d), + so: Options{Backdate: time.Second}, + cert: new(x509.Certificate), + valid: func(cert *x509.Certificate) { + n := now() + assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore)) + assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore)) + + assert.True(t, n.Add(d).After(cert.NotAfter)) + assert.True(t, n.Add(d).Add(-1*time.Minute).Before(cert.NotAfter)) + }, + } + }, + "ok/notAfter-set": func() test { + na := now().Add(10 * time.Minute).UTC() + return test{ + pdd: profileDefaultDuration(0), + so: Options{NotAfter: NewTimeDuration(na)}, + cert: new(x509.Certificate), + valid: func(cert *x509.Certificate) { + n := now() + assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore)) + assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore)) + + assert.Equals(t, cert.NotAfter, na) + }, + } + }, + "ok/notBefore-and-notAfter-set": func() test { + nb := time.Now().Add(5 * time.Minute).UTC() + na := time.Now().Add(10 * time.Minute).UTC() + d := 4 * time.Hour + return test{ + pdd: profileDefaultDuration(d), + so: Options{NotBefore: NewTimeDuration(nb), NotAfter: NewTimeDuration(na)}, + cert: new(x509.Certificate), + valid: func(cert *x509.Certificate) { + assert.Equals(t, cert.NotBefore, nb) + assert.Equals(t, cert.NotAfter, na) + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tt := run() + prof := &x509util.Leaf{} + prof.SetSubject(tt.cert) + assert.FatalError(t, tt.pdd.Option(tt.so)(prof), "unexpected error") + tt.valid(prof.Subject()) }) } } @@ -381,43 +543,3 @@ func Test_profileLimitDuration_Option(t *testing.T) { }) } } - -func Test_profileDefaultDuration_Option(t *testing.T) { - tm, fn := mockNow() - defer fn() - - v := profileDefaultDuration(24 * time.Hour) - type args struct { - so Options - } - tests := []struct { - name string - v profileDefaultDuration - args args - want *x509.Certificate - }{ - {"default", v, args{Options{}}, &x509.Certificate{NotBefore: tm, NotAfter: tm.Add(24 * time.Hour)}}, - {"backdate", v, args{Options{Backdate: 1 * time.Minute}}, &x509.Certificate{NotBefore: tm.Add(-1 * time.Minute), NotAfter: tm.Add(24 * time.Hour)}}, - {"notBefore", v, args{Options{NotBefore: NewTimeDuration(tm.Add(10 * time.Second))}}, &x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(24*time.Hour + 10*time.Second)}}, - {"notAfter", v, args{Options{NotAfter: NewTimeDuration(tm.Add(1 * time.Hour))}}, &x509.Certificate{NotBefore: tm, NotAfter: tm.Add(1 * time.Hour)}}, - {"notBefore and notAfter", v, args{Options{NotBefore: NewTimeDuration(tm.Add(10 * time.Second)), NotAfter: NewTimeDuration(tm.Add(1 * time.Hour))}}, - &x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(1 * time.Hour)}}, - {"notBefore and backdate", v, args{Options{Backdate: 1 * time.Minute, NotBefore: NewTimeDuration(tm.Add(10 * time.Second))}}, - &x509.Certificate{NotBefore: tm.Add(10 * time.Second), NotAfter: tm.Add(24*time.Hour + 10*time.Second)}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cert := &x509.Certificate{} - profile := &x509util.Leaf{} - profile.SetSubject(cert) - - fn := tt.v.Option(tt.args.so) - if err := fn(profile); err != nil { - t.Errorf("profileDefaultDuration.Option() error = %v", err) - } - if !reflect.DeepEqual(cert, tt.want) { - t.Errorf("profileDefaultDuration.Option() = %v, \nwant %v", cert, tt.want) - } - }) - } -} diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 643e0645..ec67baf1 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -78,7 +78,7 @@ func (o SSHOptions) Modify(cert *ssh.Certificate) error { case SSHHostCert: cert.CertType = ssh.HostCert default: - return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType) + return errors.Errorf("ssh certificate has an unknown type - %s", o.CertType) } cert.KeyId = o.KeyID @@ -126,11 +126,11 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given +// sshCertKeyIDModifier is an SSHCertificateModifier that sets the given // Key ID in the SSH certificate. -type sshCertificateKeyIDModifier string +type sshCertKeyIDModifier string -func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error { +func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error { cert.KeyId = string(m) return nil } @@ -145,30 +145,30 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the +// sshCertValidAfterModifier is an SSHCertificateModifier that sets the // ValidAfter in the SSH certificate. -type sshCertificateValidAfterModifier uint64 +type sshCertValidAfterModifier uint64 -func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error { +func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error { cert.ValidAfter = uint64(m) return nil } -// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the +// sshCertValidBeforeModifier is an SSHCertificateModifier that sets the // ValidBefore in the SSH certificate. -type sshCertificateValidBeforeModifier uint64 +type sshCertValidBeforeModifier uint64 -func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error { +func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error { cert.ValidBefore = uint64(m) return nil } -// sshCertificateDefaultModifier implements a SSHCertificateModifier that +// sshCertDefaultsModifier implements a SSHCertificateModifier that // modifies the certificate with the given options if they are not set. -type sshCertificateDefaultsModifier SSHOptions +type sshCertDefaultsModifier SSHOptions // Modify implements the SSHCertificateModifier interface. -func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error { +func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error { if cert.CertType == 0 { cert.CertType = sshCertTypeUInt32(m.CertType) } diff --git a/authority/provisioner/sign_ssh_options_test.go b/authority/provisioner/sign_ssh_options_test.go index e447065b..87716e37 100644 --- a/authority/provisioner/sign_ssh_options_test.go +++ b/authority/provisioner/sign_ssh_options_test.go @@ -38,6 +38,457 @@ func TestSSHOptions_Type(t *testing.T) { } } +func TestSSHOptions_Modify(t *testing.T) { + type test struct { + so *SSHOptions + cert *ssh.Certificate + valid func(*ssh.Certificate) + err error + } + tests := map[string](func() test){ + "fail/unexpected-cert-type": func() test { + return test{ + so: &SSHOptions{CertType: "foo"}, + cert: new(ssh.Certificate), + err: errors.Errorf("ssh certificate has an unknown type - foo"), + } + }, + "fail/validAfter-greater-validBefore": func() test { + return test{ + so: &SSHOptions{CertType: "user"}, + cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)}, + err: errors.Errorf("ssh certificate valid after cannot be greater than valid before"), + } + }, + "ok/user-cert": func() test { + return test{ + so: &SSHOptions{CertType: "user"}, + cert: new(ssh.Certificate), + valid: func(cert *ssh.Certificate) { + assert.Equals(t, cert.CertType, uint32(ssh.UserCert)) + }, + } + }, + "ok/host-cert": func() test { + return test{ + so: &SSHOptions{CertType: "host"}, + cert: new(ssh.Certificate), + valid: func(cert *ssh.Certificate) { + assert.Equals(t, cert.CertType, uint32(ssh.HostCert)) + }, + } + }, + "ok": func() test { + va := time.Now().Add(5 * time.Minute) + vb := time.Now().Add(1 * time.Hour) + so := &SSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"}, + ValidAfter: NewTimeDuration(va), ValidBefore: NewTimeDuration(vb)} + return test{ + so: so, + cert: new(ssh.Certificate), + valid: func(cert *ssh.Certificate) { + assert.Equals(t, cert.CertType, uint32(ssh.HostCert)) + assert.Equals(t, cert.KeyId, so.KeyID) + assert.Equals(t, cert.ValidPrincipals, so.Principals) + assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix())) + assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix())) + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run() + if err := tc.so.Modify(tc.cert); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + tc.valid(tc.cert) + } + } + }) + } +} + +func TestSSHOptions_Match(t *testing.T) { + type test struct { + so SSHOptions + cmp SSHOptions + err error + } + tests := map[string](func() test){ + "fail/cert-type": func() test { + return test{ + so: SSHOptions{CertType: "foo"}, + cmp: SSHOptions{CertType: "bar"}, + err: errors.Errorf("ssh certificate type does not match - got bar, want foo"), + } + }, + "fail/pricipals": func() test { + return test{ + so: SSHOptions{Principals: []string{"foo"}}, + cmp: SSHOptions{Principals: []string{"bar"}}, + err: errors.Errorf("ssh certificate principals does not match - got [bar], want [foo]"), + } + }, + "fail/validAfter": func() test { + return test{ + so: SSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))}, + cmp: SSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))}, + err: errors.Errorf("ssh certificate valid after does not match"), + } + }, + "fail/validBefore": func() test { + return test{ + so: SSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))}, + cmp: SSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))}, + err: errors.Errorf("ssh certificate valid before does not match"), + } + }, + "ok/original-empty": func() test { + return test{ + so: SSHOptions{}, + cmp: SSHOptions{ + CertType: "foo", + Principals: []string{"foo"}, + ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)), + ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)), + }, + } + }, + "ok/cmp-empty": func() test { + return test{ + cmp: SSHOptions{}, + so: SSHOptions{ + CertType: "foo", + Principals: []string{"foo"}, + ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)), + ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)), + }, + } + }, + "ok/equal": func() test { + n := time.Now() + va := NewTimeDuration(n.Add(1 * time.Minute)) + vb := NewTimeDuration(n.Add(5 * time.Minute)) + return test{ + cmp: SSHOptions{ + CertType: "foo", + Principals: []string{"foo"}, + ValidAfter: va, + ValidBefore: vb, + }, + so: SSHOptions{ + CertType: "foo", + Principals: []string{"foo"}, + ValidAfter: va, + ValidBefore: vb, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run() + if err := tc.so.match(tc.cmp); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func Test_sshCertPrincipalsModifier_Modify(t *testing.T) { + type test struct { + modifier sshCertPrincipalsModifier + cert *ssh.Certificate + expected []string + } + tests := map[string](func() test){ + "ok": func() test { + a := []string{"foo", "bar"} + return test{ + modifier: sshCertPrincipalsModifier(a), + cert: new(ssh.Certificate), + expected: a, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run() + if assert.Nil(t, tc.modifier.Modify(tc.cert)) { + assert.Equals(t, tc.cert.ValidPrincipals, tc.expected) + } + }) + } +} + +func Test_sshCertKeyIDModifier_Modify(t *testing.T) { + type test struct { + modifier sshCertKeyIDModifier + cert *ssh.Certificate + expected string + } + tests := map[string](func() test){ + "ok": func() test { + a := "foo" + return test{ + modifier: sshCertKeyIDModifier(a), + cert: new(ssh.Certificate), + expected: a, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run() + if assert.Nil(t, tc.modifier.Modify(tc.cert)) { + assert.Equals(t, tc.cert.KeyId, tc.expected) + } + }) + } +} + +func Test_sshCertTypeModifier_Modify(t *testing.T) { + type test struct { + modifier sshCertTypeModifier + cert *ssh.Certificate + expected uint32 + } + tests := map[string](func() test){ + "ok/user": func() test { + return test{ + modifier: sshCertTypeModifier("user"), + cert: new(ssh.Certificate), + expected: ssh.UserCert, + } + }, + "ok/host": func() test { + return test{ + modifier: sshCertTypeModifier("host"), + cert: new(ssh.Certificate), + expected: ssh.HostCert, + } + }, + "ok/default": func() test { + return test{ + modifier: sshCertTypeModifier("foo"), + cert: new(ssh.Certificate), + expected: 0, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run() + if assert.Nil(t, tc.modifier.Modify(tc.cert)) { + assert.Equals(t, tc.cert.CertType, uint32(tc.expected)) + } + }) + } +} + +func Test_sshCertValidAfterModifier_Modify(t *testing.T) { + type test struct { + modifier sshCertValidAfterModifier + cert *ssh.Certificate + expected uint64 + } + tests := map[string](func() test){ + "ok": func() test { + return test{ + modifier: sshCertValidAfterModifier(15), + cert: new(ssh.Certificate), + expected: 15, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run() + if assert.Nil(t, tc.modifier.Modify(tc.cert)) { + assert.Equals(t, tc.cert.ValidAfter, tc.expected) + } + }) + } +} + +func Test_sshCertDefaultsModifier_Modify(t *testing.T) { + type test struct { + modifier sshCertDefaultsModifier + cert *ssh.Certificate + valid func(*ssh.Certificate) + } + tests := map[string](func() test){ + "ok/changes": func() test { + n := time.Now() + va := NewTimeDuration(n.Add(1 * time.Minute)) + vb := NewTimeDuration(n.Add(5 * time.Minute)) + so := SSHOptions{ + Principals: []string{"foo", "bar"}, + CertType: "host", + ValidAfter: va, + ValidBefore: vb, + } + return test{ + modifier: sshCertDefaultsModifier(so), + cert: new(ssh.Certificate), + valid: func(cert *ssh.Certificate) { + assert.Equals(t, cert.ValidPrincipals, so.Principals) + assert.Equals(t, cert.CertType, uint32(ssh.HostCert)) + assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix())) + assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix())) + }, + } + }, + "ok/no-changes": func() test { + n := time.Now() + so := SSHOptions{ + Principals: []string{"foo", "bar"}, + CertType: "host", + ValidAfter: NewTimeDuration(n.Add(15 * time.Minute)), + ValidBefore: NewTimeDuration(n.Add(25 * time.Minute)), + } + return test{ + modifier: sshCertDefaultsModifier(so), + cert: &ssh.Certificate{ + CertType: uint32(ssh.UserCert), + ValidPrincipals: []string{"zap", "zoop"}, + ValidAfter: 15, + ValidBefore: 25, + }, + valid: func(cert *ssh.Certificate) { + assert.Equals(t, cert.ValidPrincipals, []string{"zap", "zoop"}) + assert.Equals(t, cert.CertType, uint32(ssh.UserCert)) + assert.Equals(t, cert.ValidAfter, uint64(15)) + assert.Equals(t, cert.ValidBefore, uint64(25)) + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run() + if assert.Nil(t, tc.modifier.Modify(tc.cert)) { + tc.valid(tc.cert) + } + }) + } +} + +func Test_sshDefaultExtensionModifier_Modify(t *testing.T) { + type test struct { + modifier sshDefaultExtensionModifier + cert *ssh.Certificate + valid func(*ssh.Certificate) + err error + } + tests := map[string](func() test){ + "fail/unexpected-cert-type": func() test { + cert := &ssh.Certificate{CertType: 3} + return test{ + modifier: sshDefaultExtensionModifier{}, + cert: cert, + err: errors.New("ssh certificate type has not been set or is invalid"), + } + }, + "ok/host": func() test { + cert := &ssh.Certificate{CertType: ssh.HostCert} + return test{ + modifier: sshDefaultExtensionModifier{}, + cert: cert, + valid: func(cert *ssh.Certificate) { + assert.Len(t, 0, cert.Extensions) + }, + } + }, + "ok/user/extensions-exists": func() test { + cert := &ssh.Certificate{CertType: ssh.UserCert, Permissions: ssh.Permissions{Extensions: map[string]string{ + "foo": "bar", + }}} + return test{ + modifier: sshDefaultExtensionModifier{}, + cert: cert, + valid: func(cert *ssh.Certificate) { + val, ok := cert.Extensions["foo"] + assert.True(t, ok) + assert.Equals(t, val, "bar") + + val, ok = cert.Extensions["permit-X11-forwarding"] + assert.True(t, ok) + assert.Equals(t, val, "") + + val, ok = cert.Extensions["permit-agent-forwarding"] + assert.True(t, ok) + assert.Equals(t, val, "") + + val, ok = cert.Extensions["permit-port-forwarding"] + assert.True(t, ok) + assert.Equals(t, val, "") + + val, ok = cert.Extensions["permit-pty"] + assert.True(t, ok) + assert.Equals(t, val, "") + + val, ok = cert.Extensions["permit-user-rc"] + assert.True(t, ok) + assert.Equals(t, val, "") + }, + } + }, + "ok/user/no-extensions": func() test { + return test{ + modifier: sshDefaultExtensionModifier{}, + cert: &ssh.Certificate{CertType: ssh.UserCert}, + valid: func(cert *ssh.Certificate) { + _, ok := cert.Extensions["foo"] + assert.False(t, ok) + + val, ok := cert.Extensions["permit-X11-forwarding"] + assert.True(t, ok) + assert.Equals(t, val, "") + + val, ok = cert.Extensions["permit-agent-forwarding"] + assert.True(t, ok) + assert.Equals(t, val, "") + + val, ok = cert.Extensions["permit-port-forwarding"] + assert.True(t, ok) + assert.Equals(t, val, "") + + val, ok = cert.Extensions["permit-pty"] + assert.True(t, ok) + assert.Equals(t, val, "") + + val, ok = cert.Extensions["permit-user-rc"] + assert.True(t, ok) + assert.Equals(t, val, "") + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run() + if err := tc.modifier.Modify(tc.cert); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + tc.valid(tc.cert) + } + } + }) + } +} + func Test_sshCertificateDefaultValidator_Valid(t *testing.T) { pub, _, err := keys.GenerateDefaultKeyPair() assert.FatalError(t, err) @@ -505,7 +956,7 @@ func Test_sshDefaultDuration_Option(t *testing.T) { {"host backdate", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false}, {"user validAfter", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}}, - &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Minute), ValidBefore: unix(17 * time.Hour)}, false}, + &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Hour), ValidBefore: unix(17 * time.Hour)}, false}, {"user validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false}, {"host validAfter validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}}, diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 407a7a3a..3c55aada 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -3,11 +3,13 @@ package provisioner import ( "context" "encoding/base64" + "net/http" "strconv" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" "golang.org/x/crypto/ssh" ) @@ -99,33 +101,31 @@ func (p *SSHPOP) Init(config Config) error { // claims for case specific downstream parsing. // e.g. a Sign request will auth/validate different fields than a Revoke request. func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { - sshCert, err := ExtractSSHPOPCert(token) + sshCert, jwt, err := ExtractSSHPOPCert(token) if err != nil { - return nil, errors.Wrap(err, "authorizeToken ssh-pop") + return nil, errs.Wrap(http.StatusUnauthorized, err, + "sshpop.authorizeToken; error extracting sshpop header from token") } // Check for revocation. if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil { - return nil, errors.Wrap(err, "authorizeToken ssh-pop") + return nil, errs.Wrap(http.StatusInternalServerError, err, + "sshpop.authorizeToken; error checking checking sshpop cert revocation") } else if isRevoked { - return nil, errors.New("authorizeToken ssh-pop: ssh certificate has been revoked") + return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate is revoked")) } - jwt, err := jose.ParseSigned(token) - if err != nil { - return nil, errors.Wrapf(err, "error parsing token") - } // Check validity period of the certificate. n := time.Now() if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { - return nil, errors.New("sshpop certificate validAfter is in the future") + return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future")) } if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { - return nil, errors.New("sshpop certificate validBefore is in the past") + return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past")) } sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) if !ok { - return nil, errors.New("ssh public key could not be cast to ssh CryptoPublicKey") + return nil, errs.InternalServerError(errors.New("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")) } pubKey := sshCryptoPubKey.CryptoPublicKey() @@ -146,7 +146,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa } } if !found { - return nil, errors.New("error: provisioner could could not verify the sshpop header certificate") + return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")) } // Using the ssh certificates key to validate the claims accomplishes two @@ -156,7 +156,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa // 2. Asserts that the claims are valid - have not been tampered with. var claims sshPOPPayload if err = jwt.Claims(pubKey, &claims); err != nil { - return nil, errors.Wrap(err, "error parsing claims") + return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; error parsing sshpop token claims") } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -165,16 +165,17 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa Issuer: p.Name, Time: time.Now().UTC(), }, time.Minute); err != nil { - return nil, errors.Wrapf(err, "invalid token") + return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; invalid sshpop token") } // validate audiences with the defaults if !matchesAudience(claims.Audience, audiences) { - return nil, errors.New("invalid token: invalid audience claim (aud)") + return nil, errs.Unauthorized(errors.Errorf("sshpop.authorizeToken; sshpop token has invalid audience "+ + "claim (aud): expected %s, but got %s", audiences, claims.Audience)) } if claims.Subject == "" { - return nil, errors.New("token subject cannot be empty") + return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty")) } claims.sshCert = sshCert @@ -186,12 +187,13 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) if err != nil { - return err + return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { - return errors.New("token subject must be equivalent to certificate serial number") + return errs.BadRequest(errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject " + + "must be equivalent to sshpop certificate serial number")) } - return err + return nil } // AuthorizeSSHRenew validates the authorization token and extracts/validates @@ -199,10 +201,10 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { claims, err := p.authorizeToken(token, p.audiences.SSHRenew) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { - return nil, errors.New("sshpop AuthorizeSSHRenew: sshpop certificate must be a host ssh certificate") + return nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")) } return claims.sshCert, nil @@ -214,10 +216,10 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { claims, err := p.authorizeToken(token, p.audiences.SSHRekey) if err != nil { - return nil, nil, err + return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } if claims.sshCert.CertType != ssh.HostCert { - return nil, nil, errors.New("sshpop AuthorizeSSHRekey: sshpop certificate must be a host ssh certificate") + return nil, nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")) } return claims.sshCert, []SignOption{ // Validate public key @@ -232,33 +234,34 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate // in the sshpop header. If the header is missing, an error is returned. -func ExtractSSHPOPCert(token string) (*ssh.Certificate, error) { +func ExtractSSHPOPCert(token string) (*ssh.Certificate, *jose.JSONWebToken, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, errors.Wrapf(err, "error parsing token") + return nil, nil, errors.Wrapf(err, "extractSSHPOPCert; error parsing token") } encodedSSHCert, ok := jwt.Headers[0].ExtraHeaders["sshpop"] if !ok { - return nil, errors.New("token missing sshpop header") + return nil, nil, errors.New("extractSSHPOPCert; token missing sshpop header") } encodedSSHCertStr, ok := encodedSSHCert.(string) if !ok { - return nil, errors.New("error unexpected type for sshpop header") + return nil, nil, errors.Errorf("extractSSHPOPCert; error unexpected type for sshpop header: "+ + "want 'string', but got '%T'", encodedSSHCert) } sshCertBytes, err := base64.StdEncoding.DecodeString(encodedSSHCertStr) if err != nil { - return nil, errors.Wrap(err, "error decoding sshpop header") + return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error base64 decoding sshpop header") } sshPub, err := ssh.ParsePublicKey(sshCertBytes) if err != nil { - return nil, errors.Wrap(err, "error parsing ssh public key") + return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error parsing ssh public key") } sshCert, ok := sshPub.(*ssh.Certificate) if !ok { - return nil, errors.New("error converting ssh public key to ssh certificate") + return nil, nil, errors.New("extractSSHPOPCert; error converting ssh public key to ssh certificate") } - return sshCert, nil + return sshCert, jwt, nil } func bytesForSigning(cert *ssh.Certificate) []byte { diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go new file mode 100644 index 00000000..32f58879 --- /dev/null +++ b/authority/provisioner/sshpop_test.go @@ -0,0 +1,684 @@ +package provisioner + +import ( + "context" + "crypto" + "crypto/rand" + "encoding/base64" + "net/http" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/cli/crypto/pemutil" + "github.com/smallstep/cli/jose" + "golang.org/x/crypto/ssh" +) + +func TestSSHPOP_Getters(t *testing.T) { + p, err := generateSSHPOP() + assert.FatalError(t, err) + id := "sshpop/" + p.Name + if got := p.GetID(); got != id { + t.Errorf("SSHPOP.GetID() = %v, want %v", got, id) + } + if got := p.GetName(); got != p.Name { + t.Errorf("SSHPOP.GetName() = %v, want %v", got, p.Name) + } + if got := p.GetType(); got != TypeSSHPOP { + t.Errorf("SSHPOP.GetType() = %v, want %v", got, TypeSSHPOP) + } + kid, key, ok := p.GetEncryptedKey() + if kid != "" || key != "" || ok == true { + t.Errorf("SSHPOP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, "", "", false) + } +} + +func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) + if err != nil { + return nil, nil, err + } + cert.Key, err = ssh.NewPublicKey(jwk.Public().Key) + if err != nil { + return nil, nil, err + } + if err = cert.SignCert(rand.Reader, signer); err != nil { + return nil, nil, err + } + return cert, jwk, nil +} + +func generateSSHPOPToken(p Interface, cert *ssh.Certificate, jwk *jose.JSONWebKey) (string, error) { + return generateToken("foo", p.GetName(), testAudiences.Sign[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) +} + +func TestSSHPOP_authorizeToken(t *testing.T) { + key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + + type test struct { + p *SSHPOP + token string + err error + code int + } + tests := map[string]func(*testing.T) test{ + "fail/bad-token": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "), + } + }, + "fail/error-revoked-db-check": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, errors.New("force") + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateSSHPOPToken(p, cert, jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusInternalServerError, + err: errors.New("sshpop.authorizeToken; error checking checking sshpop cert revocation: force"), + } + }, + "fail/cert-already-revoked": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return true, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateSSHPOPToken(p, cert, jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; sshpop certificate is revoked"), + } + }, + "fail/cert-not-yet-valid": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{ + CertType: ssh.UserCert, + ValidAfter: uint64(time.Now().Add(time.Minute).Unix()), + }, sshSigner) + assert.FatalError(t, err) + tok, err := generateSSHPOPToken(p, cert, jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future"), + } + }, + "fail/cert-past-validity": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{ + CertType: ssh.UserCert, + ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()), + }, sshSigner) + assert.FatalError(t, err) + tok, err := generateSSHPOPToken(p, cert, jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past"), + } + }, + "fail/no-signer-found": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateSSHPOPToken(p, cert, jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate"), + } + }, + "fail/error-parsing-claims-bad-sig": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + tok, err := generateSSHPOPToken(p, cert, otherJWK) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; error parsing sshpop token claims"), + } + }, + "fail/invalid-claims-issuer": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; invalid sshpop token"), + } + }, + "fail/invalid-audience": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateToken("foo", p.GetName(), "invalid-aud", "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; sshpop token has invalid audience claim (aud)"), + } + }, + "fail/empty-subject": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty"), + } + }, + "ok": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateSSHPOPToken(p, cert, jwk) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.NotNil(t, claims) + } + } + }) + } +} + +func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { + key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + + type test struct { + p *SSHPOP + token string + err error + code int + } + tests := map[string]func(*testing.T) test{ + "fail/bad-token": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("sshpop.AuthorizeSSHRevoke: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "), + } + }, + "fail/subject-not-equal-serial": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusBadRequest, + err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"), + } + }, + "ok": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner) + assert.FatalError(t, err) + tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { + key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key") + assert.FatalError(t, err) + userSigner, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer") + sshUserSigner, err := ssh.NewSignerFromSigner(userSigner) + assert.FatalError(t, err) + + hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + hostSigner, ok := hostKey.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer") + sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner) + assert.FatalError(t, err) + + type test struct { + p *SSHPOP + token string + cert *ssh.Certificate + err error + code int + } + tests := map[string]func(*testing.T) test{ + "fail/bad-token": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("sshpop.AuthorizeSSHRenew: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "), + } + }, + "fail/not-host-cert": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner) + assert.FatalError(t, err) + tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusBadRequest, + err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"), + } + }, + "ok": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner) + assert.FatalError(t, err) + tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + cert: cert, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.cert.Nonce, cert.Nonce) + } + } + }) + } +} + +func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { + key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key") + assert.FatalError(t, err) + userSigner, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer") + sshUserSigner, err := ssh.NewSignerFromSigner(userSigner) + assert.FatalError(t, err) + + hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + hostSigner, ok := hostKey.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer") + sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner) + assert.FatalError(t, err) + + type test struct { + p *SSHPOP + token string + cert *ssh.Certificate + err error + code int + } + tests := map[string]func(*testing.T) test{ + "fail/bad-token": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("sshpop.AuthorizeSSHRekey: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "), + } + }, + "fail/not-host-cert": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner) + assert.FatalError(t, err) + tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusBadRequest, + err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"), + } + }, + "ok": func(t *testing.T) test { + p, err := generateSSHPOP() + assert.FatalError(t, err) + p.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner) + assert.FatalError(t, err) + tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + cert: cert, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Len(t, 3, opts) + for _, o := range opts { + switch v := o.(type) { + case *sshDefaultPublicKeyValidator: + case *sshCertificateDefaultValidator: + case *sshCertificateValidityValidator: + assert.Equals(t, v.Claimer, tc.p.claimer) + default: + assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + } + } + assert.Equals(t, tc.cert.Nonce, cert.Nonce) + } + } + }) + } +} + +func TestSSHPOP_ExtractSSHPOPCert(t *testing.T) { + hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + hostSigner, ok := hostKey.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer") + sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner) + assert.FatalError(t, err) + + type test struct { + token string + cert *ssh.Certificate + jwk *jose.JSONWebKey + err error + } + tests := map[string]func(*testing.T) test{ + "fail/bad-token": func(t *testing.T) test { + return test{ + token: "foo", + err: errors.New("extractSSHPOPCert; error parsing token"), + } + }, + "fail/sshpop-missing": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + tok, err := generateToken("sub", "sshpop-provisioner", testAudiences.SSHRekey[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk) + assert.FatalError(t, err) + return test{ + token: tok, + err: errors.New("extractSSHPOPCert; token missing sshpop header"), + } + }, + "fail/wrong-sshpop-type": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error { + so.WithHeader("sshpop", 12345) + return nil + }) + assert.FatalError(t, err) + return test{ + token: tok, + err: errors.New("extractSSHPOPCert; error unexpected type for sshpop header: "), + } + }, + "fail/base64decode-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error { + so.WithHeader("sshpop", "!@#$%^&*") + return nil + }) + assert.FatalError(t, err) + return test{ + token: tok, + err: errors.New("extractSSHPOPCert; error base64 decoding sshpop header: illegal base64"), + } + }, + "fail/parsing-sshpop-pubkey": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error { + so.WithHeader("sshpop", base64.StdEncoding.EncodeToString([]byte("foo"))) + return nil + }) + assert.FatalError(t, err) + return test{ + token: tok, + err: errors.New("extractSSHPOPCert; error parsing ssh public key"), + } + }, + "ok": func(t *testing.T) test { + cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner) + + assert.FatalError(t, err) + tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return test{ + token: tok, + jwk: jwk, + cert: cert, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if cert, jwt, err := ExtractSSHPOPCert(tc.token); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.cert.Nonce, cert.Nonce) + assert.Equals(t, tc.jwk.KeyID, jwt.Headers[0].KeyID) + } + } + }) + } +} diff --git a/authority/provisioner/testdata/bar.pub b/authority/provisioner/testdata/certs/bar.pub similarity index 100% rename from authority/provisioner/testdata/bar.pub rename to authority/provisioner/testdata/certs/bar.pub diff --git a/authority/provisioner/testdata/ecdsa.csr b/authority/provisioner/testdata/certs/ecdsa.csr similarity index 100% rename from authority/provisioner/testdata/ecdsa.csr rename to authority/provisioner/testdata/certs/ecdsa.csr diff --git a/authority/provisioner/testdata/ed25519.csr b/authority/provisioner/testdata/certs/ed25519.csr similarity index 100% rename from authority/provisioner/testdata/ed25519.csr rename to authority/provisioner/testdata/certs/ed25519.csr diff --git a/authority/provisioner/testdata/foo.pub b/authority/provisioner/testdata/certs/foo.pub similarity index 100% rename from authority/provisioner/testdata/foo.pub rename to authority/provisioner/testdata/certs/foo.pub diff --git a/authority/provisioner/testdata/root_ca.crt b/authority/provisioner/testdata/certs/root_ca.crt similarity index 100% rename from authority/provisioner/testdata/root_ca.crt rename to authority/provisioner/testdata/certs/root_ca.crt diff --git a/authority/provisioner/testdata/rsa.csr b/authority/provisioner/testdata/certs/rsa.csr similarity index 100% rename from authority/provisioner/testdata/rsa.csr rename to authority/provisioner/testdata/certs/rsa.csr diff --git a/authority/provisioner/testdata/short-rsa.csr b/authority/provisioner/testdata/certs/short-rsa.csr similarity index 100% rename from authority/provisioner/testdata/short-rsa.csr rename to authority/provisioner/testdata/certs/short-rsa.csr diff --git a/authority/provisioner/testdata/certs/ssh_host_ca_key.pub b/authority/provisioner/testdata/certs/ssh_host_ca_key.pub new file mode 100644 index 00000000..aa5685da --- /dev/null +++ b/authority/provisioner/testdata/certs/ssh_host_ca_key.pub @@ -0,0 +1 @@ +ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY= diff --git a/authority/provisioner/testdata/certs/ssh_user_ca_key.pub b/authority/provisioner/testdata/certs/ssh_user_ca_key.pub new file mode 100644 index 00000000..5909ce43 --- /dev/null +++ b/authority/provisioner/testdata/certs/ssh_user_ca_key.pub @@ -0,0 +1 @@ +ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y= diff --git a/authority/provisioner/testdata/x5c-leaf.crt b/authority/provisioner/testdata/certs/x5c-leaf.crt similarity index 100% rename from authority/provisioner/testdata/x5c-leaf.crt rename to authority/provisioner/testdata/certs/x5c-leaf.crt diff --git a/authority/provisioner/testdata/bar.priv b/authority/provisioner/testdata/secrets/bar.priv similarity index 100% rename from authority/provisioner/testdata/bar.priv rename to authority/provisioner/testdata/secrets/bar.priv diff --git a/authority/provisioner/testdata/secrets/bar_host_ssh_key b/authority/provisioner/testdata/secrets/bar_host_ssh_key new file mode 100644 index 00000000..7662c1a6 --- /dev/null +++ b/authority/provisioner/testdata/secrets/bar_host_ssh_key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIHzAUYu3h8e1gL5ONGZo+lghJJa9rl1TvP2UlqDXazxvoAoGCCqGSM49 +AwEHoUQDQgAEOLScS+1Yzmqdyots9lSC0tzTSXUXEgyOD9wYrQ0BqnVZtBXlQw1p +m3fnF/7Ehl6bD1YZWjrF1t+IBZQMq1uBBw== +-----END EC PRIVATE KEY----- diff --git a/authority/provisioner/testdata/ecdsa.key b/authority/provisioner/testdata/secrets/ecdsa.key similarity index 100% rename from authority/provisioner/testdata/ecdsa.key rename to authority/provisioner/testdata/secrets/ecdsa.key diff --git a/authority/provisioner/testdata/ed25519.key b/authority/provisioner/testdata/secrets/ed25519.key similarity index 100% rename from authority/provisioner/testdata/ed25519.key rename to authority/provisioner/testdata/secrets/ed25519.key diff --git a/authority/provisioner/testdata/foo.priv b/authority/provisioner/testdata/secrets/foo.priv similarity index 100% rename from authority/provisioner/testdata/foo.priv rename to authority/provisioner/testdata/secrets/foo.priv diff --git a/authority/provisioner/testdata/secrets/foo_user_ssh_key b/authority/provisioner/testdata/secrets/foo_user_ssh_key new file mode 100644 index 00000000..8bda30c6 --- /dev/null +++ b/authority/provisioner/testdata/secrets/foo_user_ssh_key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINWGD2xneE43YeytQzORItISxv6d/oH+9TXvDKHo6TyXoAoGCCqGSM49 +AwEHoUQDQgAEVK/EtXgVV7+7ppnQSjCtI5qb/gIGnQUF4i//F/JKKho7kRNyMDSn +BP3kndiv8Yfxg4PsyIRY5ZofbEo5eJE6bg== +-----END EC PRIVATE KEY----- diff --git a/authority/provisioner/testdata/rsa.key b/authority/provisioner/testdata/secrets/rsa.key similarity index 100% rename from authority/provisioner/testdata/rsa.key rename to authority/provisioner/testdata/secrets/rsa.key diff --git a/authority/provisioner/testdata/secrets/ssh_host_ca_key b/authority/provisioner/testdata/secrets/ssh_host_ca_key new file mode 100644 index 00000000..7a7e4c44 --- /dev/null +++ b/authority/provisioner/testdata/secrets/ssh_host_ca_key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49 +AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+ +MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg== +-----END EC PRIVATE KEY----- diff --git a/authority/provisioner/testdata/secrets/ssh_user_ca_key b/authority/provisioner/testdata/secrets/ssh_user_ca_key new file mode 100644 index 00000000..92d35ec2 --- /dev/null +++ b/authority/provisioner/testdata/secrets/ssh_user_ca_key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49 +AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h +YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg== +-----END EC PRIVATE KEY----- diff --git a/authority/provisioner/testdata/x5c-leaf.key b/authority/provisioner/testdata/secrets/x5c-leaf.key similarity index 100% rename from authority/provisioner/testdata/x5c-leaf.key rename to authority/provisioner/testdata/secrets/x5c-leaf.key diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 76c9a567..7d200d33 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -19,6 +19,7 @@ import ( "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/jose" + "golang.org/x/crypto/ssh" ) var ( @@ -47,24 +48,6 @@ var ( } ) -func provisionerClaims() *Claims { - ddr := false - des := true - return &Claims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, - DisableRenewal: &ddr, - MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &Duration{Duration: 4 * time.Hour}, - MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &des, - } -} - const awsTestCertificate = `-----BEGIN CERTIFICATE----- MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0 @@ -204,7 +187,7 @@ func generateJWK() (*JWK, error) { } func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { - fooPubB, err := ioutil.ReadFile("./testdata/foo.pub") + fooPubB, err := ioutil.ReadFile("./testdata/certs/foo.pub") if err != nil { return nil, err } @@ -212,7 +195,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { if err != nil { return nil, err } - barPubB, err := ioutil.ReadFile("./testdata/bar.pub") + barPubB, err := ioutil.ReadFile("./testdata/certs/bar.pub") if err != nil { return nil, err } @@ -240,6 +223,46 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { }, nil } +func generateSSHPOP() (*SSHPOP, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + claimer, err := NewClaimer(nil, globalProvisionerClaims) + if err != nil { + return nil, err + } + + userB, err := ioutil.ReadFile("./testdata/certs/ssh_user_ca_key.pub") + if err != nil { + return nil, err + } + userKey, _, _, _, err := ssh.ParseAuthorizedKey(userB) + if err != nil { + return nil, err + } + hostB, err := ioutil.ReadFile("./testdata/certs/ssh_host_ca_key.pub") + if err != nil { + return nil, err + } + hostKey, _, _, _, err := ssh.ParseAuthorizedKey(hostB) + if err != nil { + return nil, err + } + + return &SSHPOP{ + Name: name, + Type: "SSHPOP", + Claims: &globalProvisionerClaims, + audiences: testAudiences, + claimer: claimer, + sshPubKeys: &SSHKeys{ + UserKeys: []ssh.PublicKey{userKey}, + HostKeys: []ssh.PublicKey{hostKey}, + }, + }, nil +} + func generateX5C(root []byte) (*X5C, error) { if root == nil { root = []byte(`-----BEGIN CERTIFICATE----- @@ -589,6 +612,13 @@ func withX5CHdr(certs []*x509.Certificate) tokOption { } } +func withSSHPOPFile(cert *ssh.Certificate) tokOption { + return func(so *jose.SignerOptions) error { + so.WithHeader("sshpop", base64.StdEncoding.EncodeToString(cert.Marshal())) + return nil + } +} + func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") @@ -630,6 +660,24 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T return jose.Signed(sig).Claims(claims).CompactSerialize() } +func generateX5CSSHToken(jwk *jose.JSONWebKey, claims *x5cPayload, tokOpts ...tokOption) (string, error) { + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("kid", jwk.KeyID) + + for _, o := range tokOpts { + if err := o(so); err != nil { + return "", err + } + } + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) + if err != nil { + return "", err + } + return jose.Signed(sig).Claims(claims).CompactSerialize() +} + func getK8sSAPayload() *k8sSAPayload { return &k8sSAPayload{ Claims: jose.Claims{ diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 1be728db..692cd963 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -4,9 +4,11 @@ import ( "context" "crypto/x509" "encoding/pem" + "net/http" "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/jose" ) @@ -121,19 +123,20 @@ func (p *X5C) Init(config Config) error { func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, errors.Wrapf(err, "error parsing token") + return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c token") } verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{ Roots: p.rootPool, }) if err != nil { - return nil, errors.Wrap(err, "error verifying x5c certificate chain") + return nil, errs.Wrap(http.StatusUnauthorized, err, + "x5c.authorizeToken; error verifying x5c certificate chain in token") } leaf := verifiedChains[0][0] if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { - return nil, errors.New("certificate used to sign x5c token cannot be used for digital signature") + return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")) } // Using the leaf certificates key to validate the claims accomplishes two @@ -143,7 +146,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err // 2. Asserts that the claims are valid - have not been tampered with. var claims x5cPayload if err = jwt.Claims(leaf.PublicKey, &claims); err != nil { - return nil, errors.Wrap(err, "error parsing claims") + return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c claims") } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -152,16 +155,17 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err Issuer: p.Name, Time: time.Now().UTC(), }, time.Minute); err != nil { - return nil, errors.Wrapf(err, "invalid token") + return nil, errs.Wrapf(http.StatusUnauthorized, err, "x5c.authorizeToken; invalid x5c claims") } // validate audiences with the defaults if !matchesAudience(claims.Audience, audiences) { - return nil, errors.New("invalid token: invalid audience claim (aud)") + return nil, errs.Unauthorized(errors.Errorf("x5c.authorizeToken; x5c token has invalid audience "+ + "claim (aud); expected %s, but got %s", audiences, claims.Audience)) } if claims.Subject == "" { - return nil, errors.New("token subject cannot be empty") + return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; x5c token subject cannot be empty")) } // Save the verified chains on the x5c payload object. @@ -173,14 +177,14 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err // revoke the certificate with serial number in the `sub` property. func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { _, err := p.authorizeToken(token, p.audiences.Revoke) - return err + return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.audiences.Sign) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") } // NOTE: This is for backwards compatibility with older versions of cli @@ -209,7 +213,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // AuthorizeRenew returns an error if the renewal is disabled. func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + return errs.Unauthorized(errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())) } return nil } @@ -217,16 +221,16 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID()) + return nil, errs.Unauthorized(errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())) } claims, err := p.authorizeToken(token, p.audiences.SSHSign) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") } if claims.Step == nil || claims.Step.SSH == nil { - return nil, errors.New("authorization token must be an SSH provisioning token") + return nil, errs.Unauthorized(errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")) } opts := claims.Step.SSH @@ -245,18 +249,18 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, } t := now() if !opts.ValidAfter.IsZero() { - signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) + signOptions = append(signOptions, sshCertValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) } if !opts.ValidBefore.IsZero() { - signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) + signOptions = append(signOptions, sshCertValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) } // Make sure to define the the KeyID if opts.KeyID == "" { - signOptions = append(signOptions, sshCertificateKeyIDModifier(claims.Subject)) + signOptions = append(signOptions, sshCertKeyIDModifier(claims.Subject)) } // Default to a user certificate with no principals if not set - signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert}) + signOptions = append(signOptions, sshCertDefaultsModifier{CertType: SSHUserCert}) return append(signOptions, // Set the default extensions. diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 65147d24..775f3202 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -2,14 +2,16 @@ package provisioner import ( "context" - "crypto/x509" "net" + "net/http" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/pemutil" + "github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/jose" ) @@ -151,9 +153,15 @@ M46l92gdOozT } func TestX5C_authorizeToken(t *testing.T) { + x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") + assert.FatalError(t, err) + x5cJWK, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key") + assert.FatalError(t, err) + type test struct { p *X5C token string + code int err error } tests := map[string]func(*testing.T) test{ @@ -163,7 +171,8 @@ func TestX5C_authorizeToken(t *testing.T) { return test{ p: p, token: "foo", - err: errors.New("error parsing token"), + code: http.StatusUnauthorized, + err: errors.New("x5c.authorizeToken; error parsing x5c token"), } }, "fail/invalid-cert-chain": func(t *testing.T) test { @@ -190,7 +199,8 @@ a5wpg+9s6QIgHIW6L60F8klQX+EO3o0SBqLeNcaskA4oSZsKjEdpSGo= return test{ p: p, token: tok, - err: errors.New("error verifying x5c certificate chain: x509: certificate signed by unknown authority"), + code: http.StatusUnauthorized, + err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"), } }, "fail/doubled-up-self-signed-cert": func(t *testing.T) test { @@ -228,7 +238,8 @@ EXAHTA9L return test{ p: p, token: tok, - err: errors.New("error verifying x5c certificate chain: x509: certificate signed by unknown authority"), + code: http.StatusUnauthorized, + err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"), } }, "fail/digital-signature-ext-required": func(t *testing.T) test { @@ -269,7 +280,8 @@ lgsqsR63is+0YQ== return test{ p: p, token: tok, - err: errors.New("certificate used to sign x5c token cannot be used for digital signature"), + code: http.StatusUnauthorized, + err: errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature"), } }, "fail/signature-does-not-match-x5c-pub-key": func(t *testing.T) test { @@ -309,74 +321,58 @@ lgsqsR63is+0YQ== return test{ p: p, token: tok, - err: errors.New("error parsing claims: square/go-jose: error in cryptographic primitive"), + code: http.StatusUnauthorized, + err: errors.New("x5c.authorizeToken; error parsing x5c claims"), } }, "fail/invalid-issuer": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") - assert.FatalError(t, err) - p, err := generateX5C(nil) assert.FatalError(t, err) tok, err := generateToken("", "foobar", testAudiences.Sign[0], "", - []string{"test.smallstep.com"}, time.Now(), jwk, - withX5CHdr(certs)) + []string{"test.smallstep.com"}, time.Now(), x5cJWK, + withX5CHdr(x5cCerts)) assert.FatalError(t, err) return test{ p: p, token: tok, - err: errors.New("invalid token: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"), + code: http.StatusUnauthorized, + err: errors.New("x5c.authorizeToken; invalid x5c claims"), } }, "fail/invalid-audience": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") - assert.FatalError(t, err) - p, err := generateX5C(nil) assert.FatalError(t, err) tok, err := generateToken("", p.GetName(), "foobar", "", - []string{"test.smallstep.com"}, time.Now(), jwk, - withX5CHdr(certs)) + []string{"test.smallstep.com"}, time.Now(), x5cJWK, + withX5CHdr(x5cCerts)) assert.FatalError(t, err) return test{ p: p, token: tok, - err: errors.New("invalid token: invalid audience claim (aud)"), + code: http.StatusUnauthorized, + err: errors.New("x5c.authorizeToken; x5c token has invalid audience claim (aud)"), } }, "fail/empty-subject": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") - assert.FatalError(t, err) - p, err := generateX5C(nil) assert.FatalError(t, err) tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "", - []string{"test.smallstep.com"}, time.Now(), jwk, - withX5CHdr(certs)) + []string{"test.smallstep.com"}, time.Now(), x5cJWK, + withX5CHdr(x5cCerts)) assert.FatalError(t, err) return test{ p: p, token: tok, - err: errors.New("token subject cannot be empty"), + code: http.StatusUnauthorized, + err: errors.New("x5c.authorizeToken; x5c token subject cannot be empty"), } }, "ok": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") - assert.FatalError(t, err) - p, err := generateX5C(nil) assert.FatalError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", - []string{"test.smallstep.com"}, time.Now(), jwk, - withX5CHdr(certs)) + []string{"test.smallstep.com"}, time.Now(), x5cJWK, + withX5CHdr(x5cCerts)) assert.FatalError(t, err) return test{ p: p, @@ -389,6 +385,9 @@ lgsqsR63is+0YQ== tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { @@ -402,10 +401,15 @@ lgsqsR63is+0YQ== } func TestX5C_AuthorizeSign(t *testing.T) { + certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") + assert.FatalError(t, err) + jwk, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key") + assert.FatalError(t, err) + type test struct { p *X5C token string - ctx context.Context + code int err error dns []string emails []string @@ -418,56 +422,11 @@ func TestX5C_AuthorizeSign(t *testing.T) { return test{ p: p, token: "foo", - ctx: NewContextWithMethod(context.Background(), SignMethod), - err: errors.New("error parsing token"), - } - }, - "fail/ssh/disabled": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") - assert.FatalError(t, err) - - p, err := generateX5C(nil) - assert.FatalError(t, err) - p.claimer.claims = provisionerClaims() - *p.claimer.claims.EnableSSHCA = false - tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", - []string{"test.smallstep.com"}, time.Now(), jwk, - withX5CHdr(certs)) - assert.FatalError(t, err) - return test{ - p: p, - ctx: NewContextWithMethod(context.Background(), SignSSHMethod), - token: tok, - err: errors.Errorf("ssh ca is disabled for provisioner x5c/%s", p.GetName()), - } - }, - "fail/ssh/invalid-token": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") - assert.FatalError(t, err) - - p, err := generateX5C(nil) - assert.FatalError(t, err) - tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", - []string{"test.smallstep.com"}, time.Now(), jwk, - withX5CHdr(certs)) - assert.FatalError(t, err) - return test{ - p: p, - ctx: NewContextWithMethod(context.Background(), SignSSHMethod), - token: tok, - err: errors.New("authorization token must be an SSH provisioning token"), + code: http.StatusUnauthorized, + err: errors.New("x5c.AuthorizeSign: x5c.authorizeToken; error parsing x5c token"), } }, "ok/empty-sans": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") - assert.FatalError(t, err) - p, err := generateX5C(nil) assert.FatalError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", @@ -476,7 +435,6 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.FatalError(t, err) return test{ p: p, - ctx: NewContextWithMethod(context.Background(), SignMethod), token: tok, dns: []string{"foo"}, emails: []string{}, @@ -484,11 +442,6 @@ func TestX5C_AuthorizeSign(t *testing.T) { } }, "ok/multi-sans": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") - assert.FatalError(t, err) - p, err := generateX5C(nil) assert.FatalError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", @@ -497,7 +450,6 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.FatalError(t, err) return test{ p: p, - ctx: NewContextWithMethod(context.Background(), SignMethod), token: tok, dns: []string{"foo"}, emails: []string{"max@smallstep.com"}, @@ -508,8 +460,11 @@ func TestX5C_AuthorizeSign(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token); err != nil { + if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { @@ -554,126 +509,11 @@ func TestX5C_AuthorizeSign(t *testing.T) { } } -func TestX5C_AuthorizeSSHSign(t *testing.T) { - _, fn := mockNow() - defer fn() - type test struct { - p *X5C - token string - claims *x5cPayload - err error - } - tests := map[string]func(*testing.T) test{ - "fail/no-Step-claim": func(t *testing.T) test { - p, err := generateX5C(nil) - assert.FatalError(t, err) - return test{ - p: p, - claims: new(x5cPayload), - err: errors.New("authorization token must be an SSH provisioning token"), - } - }, - "fail/no-SSH-subattribute-in-claims": func(t *testing.T) test { - p, err := generateX5C(nil) - assert.FatalError(t, err) - return test{ - p: p, - claims: &x5cPayload{Step: new(stepPayload)}, - err: errors.New("authorization token must be an SSH provisioning token"), - } - }, - "ok/with-claims": func(t *testing.T) test { - p, err := generateX5C(nil) - assert.FatalError(t, err) - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - return test{ - p: p, - claims: &x5cPayload{ - Step: &stepPayload{SSH: &SSHOptions{ - CertType: SSHHostCert, - Principals: []string{"max", "mariano", "alan"}, - ValidAfter: TimeDuration{d: 5 * time.Minute}, - ValidBefore: TimeDuration{d: 10 * time.Minute}, - }}, - Claims: jose.Claims{Subject: "foo"}, - chains: [][]*x509.Certificate{certs}, - }, - } - }, - "ok/without-claims": func(t *testing.T) test { - p, err := generateX5C(nil) - assert.FatalError(t, err) - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") - assert.FatalError(t, err) - return test{ - p: p, - claims: &x5cPayload{ - Step: &stepPayload{SSH: &SSHOptions{}}, - Claims: jose.Claims{Subject: "foo"}, - chains: [][]*x509.Certificate{certs}, - }, - } - }, - } - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - tc := tt(t) - if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil { - if assert.NotNil(t, tc.err) { - assert.HasPrefix(t, err.Error(), tc.err.Error()) - } - } else { - if assert.Nil(t, tc.err) { - if assert.NotNil(t, opts) { - tot := 0 - nw := now() - for _, o := range opts { - switch v := o.(type) { - case sshCertificateOptionsValidator: - tc.claims.Step.SSH.ValidAfter.t = time.Time{} - tc.claims.Step.SSH.ValidBefore.t = time.Time{} - assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH) - case sshCertificateKeyIDModifier: - assert.Equals(t, string(v), "foo") - case sshCertTypeModifier: - assert.Equals(t, string(v), tc.claims.Step.SSH.CertType) - case sshCertPrincipalsModifier: - assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals) - case sshCertificateValidAfterModifier: - assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix()) - case sshCertificateValidBeforeModifier: - assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix()) - case sshCertificateDefaultsModifier: - assert.Equals(t, SSHOptions(v), SSHOptions{CertType: SSHUserCert}) - case *sshLimitDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) - assert.Equals(t, v.NotAfter, tc.claims.chains[0][0].NotAfter) - case *sshCertificateValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) - case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator, - *sshCertificateDefaultValidator: - default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) - } - tot++ - } - if len(tc.claims.Step.SSH.CertType) > 0 { - assert.Equals(t, tot, 12) - } else { - assert.Equals(t, tot, 8) - } - } - } - } - }) - } -} - func TestX5C_AuthorizeRevoke(t *testing.T) { type test struct { p *X5C token string + code int err error } tests := map[string]func(*testing.T) test{ @@ -683,13 +523,14 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { return test{ p: p, token: "foo", - err: errors.New("error parsing token"), + code: http.StatusUnauthorized, + err: errors.New("x5c.AuthorizeRevoke: x5c.authorizeToken; error parsing x5c token"), } }, "ok": func(t *testing.T) test { - certs, err := pemutil.ReadCertificateBundle("./testdata/x5c-leaf.crt") + certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") assert.FatalError(t, err) - jwk, err := jose.ParseKey("./testdata/x5c-leaf.key") + jwk, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key") assert.FatalError(t, err) p, err := generateX5C(nil) @@ -709,6 +550,9 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil { if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { @@ -719,33 +563,248 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { } func TestX5C_AuthorizeRenew(t *testing.T) { - p1, err := generateX5C(nil) - assert.FatalError(t, err) - p2, err := generateX5C(nil) - assert.FatalError(t, err) - - // disable renewal - disable := true - p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) - assert.FatalError(t, err) - - type args struct { - cert *x509.Certificate + type test struct { + p *X5C + code int + err error } - tests := []struct { - name string - prov *X5C - args args - wantErr bool - }{ - {"ok", p1, args{nil}, false}, - {"fail", p2, args{nil}, true}, + tests := map[string]func(*testing.T) test{ + "fail/renew-disabled": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + // disable renewal + disable := true + p.Claims = &Claims{DisableRenewal: &disable} + p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + return test{ + p: p, + code: http.StatusUnauthorized, + err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()), + } + }, + "ok": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + return test{ + p: p, + } + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { - t.Errorf("X5C.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if err := tc.p.AuthorizeRenew(context.TODO(), nil); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestX5C_AuthorizeSSHSign(t *testing.T) { + x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") + assert.FatalError(t, err) + x5cJWK, err := jose.ParseKey("./testdata/secrets/x5c-leaf.key") + assert.FatalError(t, err) + + _, fn := mockNow() + defer fn() + type test struct { + p *X5C + token string + claims *x5cPayload + code int + err error + } + tests := map[string]func(*testing.T) test{ + "fail/sshCA-disabled": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + // disable sshCA + enable := false + p.Claims = &Claims{EnableSSHCA: &enable} + p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()), + } + }, + "fail/invalid-token": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + return test{ + p: p, + token: "foo", + code: http.StatusUnauthorized, + err: errors.New("x5c.AuthorizeSSHSign: x5c.authorizeToken; error parsing x5c token"), + } + }, + "fail/no-Step-claim": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + tok, err := generateToken("foo", p.GetName(), testAudiences.SSHSign[0], "", + []string{"test.smallstep.com"}, time.Now(), x5cJWK, + withX5CHdr(x5cCerts)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"), + } + }, + "fail/no-SSH-subattribute-in-claims": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + + id, err := randutil.ASCII(64) + assert.FatalError(t, err) + now := time.Now() + claims := &x5cPayload{ + Claims: jose.Claims{ + ID: id, + Subject: "foo", + Issuer: p.GetName(), + IssuedAt: jose.NewNumericDate(now), + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + Audience: []string{testAudiences.SSHSign[0]}, + }, + Step: &stepPayload{}, + } + tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) + assert.FatalError(t, err) + return test{ + p: p, + token: tok, + code: http.StatusUnauthorized, + err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"), + } + }, + "ok/with-claims": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + + id, err := randutil.ASCII(64) + assert.FatalError(t, err) + now := time.Now() + claims := &x5cPayload{ + Claims: jose.Claims{ + ID: id, + Subject: "foo", + Issuer: p.GetName(), + IssuedAt: jose.NewNumericDate(now), + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + Audience: []string{testAudiences.SSHSign[0]}, + }, + Step: &stepPayload{SSH: &SSHOptions{ + CertType: SSHHostCert, + Principals: []string{"max", "mariano", "alan"}, + ValidAfter: TimeDuration{d: 5 * time.Minute}, + ValidBefore: TimeDuration{d: 10 * time.Minute}, + }}, + } + tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) + assert.FatalError(t, err) + return test{ + p: p, + claims: claims, + token: tok, + } + }, + "ok/without-claims": func(t *testing.T) test { + p, err := generateX5C(nil) + assert.FatalError(t, err) + + id, err := randutil.ASCII(64) + assert.FatalError(t, err) + now := time.Now() + claims := &x5cPayload{ + Claims: jose.Claims{ + ID: id, + Subject: "foo", + Issuer: p.GetName(), + IssuedAt: jose.NewNumericDate(now), + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + Audience: []string{testAudiences.SSHSign[0]}, + }, + Step: &stepPayload{SSH: &SSHOptions{}}, + } + tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) + assert.FatalError(t, err) + return test{ + p: p, + claims: claims, + token: tok, + } + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tc := tt(t) + if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + if assert.NotNil(t, opts) { + tot := 0 + nw := now() + for _, o := range opts { + switch v := o.(type) { + case sshCertificateOptionsValidator: + tc.claims.Step.SSH.ValidAfter.t = time.Time{} + tc.claims.Step.SSH.ValidBefore.t = time.Time{} + assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH) + case sshCertKeyIDModifier: + assert.Equals(t, string(v), "foo") + case sshCertTypeModifier: + assert.Equals(t, string(v), tc.claims.Step.SSH.CertType) + case sshCertPrincipalsModifier: + assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals) + case sshCertValidAfterModifier: + assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix()) + case sshCertValidBeforeModifier: + assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix()) + case sshCertDefaultsModifier: + assert.Equals(t, SSHOptions(v), SSHOptions{CertType: SSHUserCert}) + case *sshLimitDuration: + assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) + case *sshCertificateValidityValidator: + assert.Equals(t, v.Claimer, tc.p.claimer) + case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator, + *sshCertificateDefaultValidator: + case sshCertKeyIDValidator: + assert.Equals(t, string(v), "foo") + default: + assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + } + tot++ + } + if len(tc.claims.Step.SSH.CertType) > 0 { + assert.Equals(t, tot, 13) + } else { + assert.Equals(t, tot, 9) + } + } + } } }) } diff --git a/authority/ssh.go b/authority/ssh.go index cfd5ed37..5d80a427 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -122,10 +122,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) { // GetSSHConfig returns rendered templates for clients (user) or servers (host). func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) { if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil { - return nil, &apiError{ - err: errors.New("getSSHConfig: ssh is not configured"), - code: http.StatusNotFound, - } + return nil, errs.NotFound(errors.New("getSSHConfig: ssh is not configured")) } var ts []templates.Template @@ -139,10 +136,7 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template ts = a.config.Templates.SSH.Host } default: - return nil, &apiError{ - err: errors.Errorf("getSSHConfig: type %s is not valid", typ), - code: http.StatusBadRequest, - } + return nil, errs.BadRequest(errors.Errorf("getSSHConfig: type %s is not valid", typ)) } // Merge user and default data @@ -174,7 +168,8 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template // hostname. func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) { if a.sshBastionFunc != nil { - return a.sshBastionFunc(user, hostname) + bs, err := a.sshBastionFunc(user, hostname) + return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion") } if a.config.SSH != nil { if a.config.SSH.Bastion != nil && a.config.SSH.Bastion.Hostname != "" { @@ -182,26 +177,7 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error } return nil, nil } - return nil, &apiError{ - err: errors.New("getSSHBastion: ssh is not configured"), - code: http.StatusNotFound, - } -} - -// authorizeSSHSign loads the provisioner from the token, checks that it has not -// been used again and calls the provisioner AuthorizeSSHSign method. Returns a -// list of methods to apply to the signing flow. -func (a *Authority) authorizeSSHSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - var errContext = apiCtx{"ott": ott} - p, err := a.authorizeToken(ctx, ott) - if err != nil { - return nil, &apiError{errors.Wrap(err, "authorizeSSHSign"), http.StatusUnauthorized, errContext} - } - opts, err := p.AuthorizeSSHSign(ctx, ott) - if err != nil { - return nil, &apiError{errors.Wrap(err, "authorizeSSHSign"), http.StatusUnauthorized, errContext} - } - return opts, nil + return nil, errs.NotFound(errors.New("authority.GetSSHBastion; ssh is not configured")) } // SignSSH creates a signed SSH certificate with the given public key and options. @@ -226,27 +202,21 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign // validate the given SSHOptions case provisioner.SSHCertificateOptionsValidator: if err := o.Valid(opts); err != nil { - return nil, &apiError{err: err, code: http.StatusForbidden} + return nil, errs.Forbidden(err) } default: - return nil, &apiError{ - err: errors.Errorf("signSSH: invalid extra option type %T", o), - code: http.StatusInternalServerError, - } + return nil, errs.InternalServerError(errors.Errorf("signSSH: invalid extra option type %T", o)) } } nonce, err := randutil.ASCII(32) if err != nil { - return nil, &apiError{err: err, code: http.StatusInternalServerError} + return nil, errs.InternalServerError(err) } var serial uint64 if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "signSSH: error reading random number"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error reading random number") } // Build base certificate with the key and some random values @@ -258,13 +228,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign // Use opts to modify the certificate if err := opts.Modify(cert); err != nil { - return nil, &apiError{err: err, code: http.StatusForbidden} + return nil, errs.Forbidden(err) } // Use provisioner modifiers for _, m := range mods { if err := m.Modify(cert); err != nil { - return nil, &apiError{err: err, code: http.StatusForbidden} + return nil, errs.Forbidden(err) } } @@ -273,25 +243,16 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign switch cert.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { - return nil, &apiError{ - err: errors.New("signSSH: user certificate signing is not enabled"), - code: http.StatusNotImplemented, - } + return nil, errs.NotImplemented(errors.New("signSSH: user certificate signing is not enabled")) } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { - return nil, &apiError{ - err: errors.New("signSSH: host certificate signing is not enabled"), - code: http.StatusNotImplemented, - } + return nil, errs.NotImplemented(errors.New("signSSH: host certificate signing is not enabled")) } signer = a.sshCAHostCertSignKey default: - return nil, &apiError{ - err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType), - code: http.StatusInternalServerError, - } + return nil, errs.InternalServerError(errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType)) } cert.SignatureKey = signer.PublicKey() @@ -302,71 +263,38 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign // Sign the certificate sig, err := signer.Sign(rand.Reader, data) if err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "signSSH: error signing certificate"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") } cert.Signature = sig // User provisioners validators for _, v := range validators { if err := v.Valid(cert); err != nil { - return nil, &apiError{err: err, code: http.StatusForbidden} + return nil, errs.Forbidden(err) } } if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { - return nil, &apiError{ - err: errors.Wrap(err, "signSSH: error storing certificate in db"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error storing certificate in db") } return cert, nil } -// authorizeSSHRenew authorizes an SSH certificate renewal request, by -// validating the contents of an SSHPOP token. -func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { - errContext := map[string]interface{}{"ott": token} - - p, err := a.authorizeToken(ctx, token) - if err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "authorizeSSHRenew"), - code: http.StatusUnauthorized, - context: errContext, - } - } - cert, err := p.AuthorizeSSHRenew(ctx, token) - if err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "authorizeSSHRenew"), - code: http.StatusUnauthorized, - context: errContext, - } - } - return cert, nil -} - // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) { nonce, err := randutil.ASCII(32) if err != nil { - return nil, &apiError{err: err, code: http.StatusInternalServerError} + return nil, errs.InternalServerError(err) } var serial uint64 if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "renewSSH: error reading random number"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error reading random number") } if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errors.New("rewnewSSH: cannot renew certificate without validity period") + return nil, errs.BadRequest(errors.New("rewnewSSH: cannot renew certificate without validity period")) } backdate := a.config.AuthorityConfig.Backdate.Duration @@ -393,25 +321,16 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) switch cert.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { - return nil, &apiError{ - err: errors.New("renewSSH: user certificate signing is not enabled"), - code: http.StatusNotImplemented, - } + return nil, errs.NotImplemented(errors.New("renewSSH: user certificate signing is not enabled")) } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { - return nil, &apiError{ - err: errors.New("renewSSH: host certificate signing is not enabled"), - code: http.StatusNotImplemented, - } + return nil, errs.NotImplemented(errors.New("renewSSH: host certificate signing is not enabled")) } signer = a.sshCAHostCertSignKey default: - return nil, &apiError{ - err: errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType), - code: http.StatusInternalServerError, - } + return nil, errs.InternalServerError(errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType)) } cert.SignatureKey = signer.PublicKey() @@ -422,47 +341,17 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) // Sign the certificate sig, err := signer.Sign(rand.Reader, data) if err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "renewSSH: error signing certificate"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error signing certificate") } cert.Signature = sig if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { - return nil, &apiError{ - err: errors.Wrap(err, "renewSSH: error storing certificate in db"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db") } return cert, nil } -// authorizeSSHRekey authorizes an SSH certificate rekey request, by -// validating the contents of an SSHPOP token. -func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) { - errContext := map[string]interface{}{"ott": token} - - p, err := a.authorizeToken(ctx, token) - if err != nil { - return nil, nil, &apiError{ - err: errors.Wrap(err, "authorizeSSHRenew"), - code: http.StatusUnauthorized, - context: errContext, - } - } - cert, opts, err := p.AuthorizeSSHRekey(ctx, token) - if err != nil { - return nil, nil, &apiError{ - err: errors.Wrap(err, "authorizeSSHRekey"), - code: http.StatusUnauthorized, - context: errContext, - } - } - return cert, opts, nil -} - // RekeySSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var validators []provisioner.SSHCertificateValidator @@ -473,28 +362,22 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp case provisioner.SSHCertificateValidator: validators = append(validators, o) default: - return nil, &apiError{ - err: errors.Errorf("rekeySSH: invalid extra option type %T", o), - code: http.StatusInternalServerError, - } + return nil, errs.InternalServerError(errors.Errorf("rekeySSH; invalid extra option type %T", o)) } } nonce, err := randutil.ASCII(32) if err != nil { - return nil, &apiError{err: err, code: http.StatusInternalServerError} + return nil, errs.InternalServerError(err) } var serial uint64 if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "rekeySSH: error reading random number"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error reading random number") } if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errors.New("rekeySSH: cannot rekey certificate without validity period") + return nil, errs.BadRequest(errors.New("rekeySSH; cannot rekey certificate without validity period")) } backdate := a.config.AuthorityConfig.Backdate.Duration @@ -521,25 +404,16 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp switch cert.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { - return nil, &apiError{ - err: errors.New("rekeySSH: user certificate signing is not enabled"), - code: http.StatusNotImplemented, - } + return nil, errs.NotImplemented(errors.New("rekeySSH; user certificate signing is not enabled")) } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { - return nil, &apiError{ - err: errors.New("rekeySSH: host certificate signing is not enabled"), - code: http.StatusNotImplemented, - } + return nil, errs.NotImplemented(errors.New("rekeySSH; host certificate signing is not enabled")) } signer = a.sshCAHostCertSignKey default: - return nil, &apiError{ - err: errors.Errorf("rekeySSH: unexpected ssh certificate type: %d", cert.CertType), - code: http.StatusInternalServerError, - } + return nil, errs.BadRequest(errors.Errorf("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)) } cert.SignatureKey = signer.PublicKey() @@ -547,80 +421,47 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp data := cert.Marshal() data = data[:len(data)-4] - // Sign the certificate + // Sign the certificate. sig, err := signer.Sign(rand.Reader, data) if err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "rekeySSH: error signing certificate"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error signing certificate") } cert.Signature = sig - // User provisioners validators + // Apply validators from provisioner.. for _, v := range validators { if err := v.Valid(cert); err != nil { - return nil, &apiError{err: err, code: http.StatusForbidden} + return nil, errs.Forbidden(err) } } if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { - return nil, &apiError{ - err: errors.Wrap(err, "rekeySSH: error storing certificate in db"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db") } return cert, nil } -// authorizeSSHRevoke authorizes an SSH certificate revoke request, by -// validating the contents of an SSHPOP token. -func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error { - errContext := map[string]interface{}{"ott": token} - - p, err := a.authorizeToken(ctx, token) - if err != nil { - return &apiError{errors.Wrap(err, "authorizeSSHRevoke"), http.StatusUnauthorized, errContext} - } - if err = p.AuthorizeSSHRevoke(ctx, token); err != nil { - return &apiError{errors.Wrap(err, "authorizeSSHRevoke"), http.StatusUnauthorized, errContext} - } - return nil -} - // SignSSHAddUser signs a certificate that provisions a new user in a server. func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { if a.sshCAUserCertSignKey == nil { - return nil, &apiError{ - err: errors.New("signSSHAddUser: user certificate signing is not enabled"), - code: http.StatusNotImplemented, - } + return nil, errs.NotImplemented(errors.New("signSSHAddUser: user certificate signing is not enabled")) } if subject.CertType != ssh.UserCert { - return nil, &apiError{ - err: errors.New("signSSHAddUser: certificate is not a user certificate"), - code: http.StatusForbidden, - } + return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate is not a user certificate")) } if len(subject.ValidPrincipals) != 1 { - return nil, &apiError{ - err: errors.New("signSSHAddUser: certificate does not have only one principal"), - code: http.StatusForbidden, - } + return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate does not have only one principal")) } nonce, err := randutil.ASCII(32) if err != nil { - return nil, &apiError{err: err, code: http.StatusInternalServerError} + return nil, errs.InternalServerError(err) } var serial uint64 if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "signSSHAddUser: error reading random number"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number") } signer := a.sshCAUserCertSignKey @@ -656,10 +497,7 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) cert.Signature = sig if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { - return nil, &apiError{ - err: errors.Wrap(err, "signSSHAddUser: error storing certificate in db"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db") } return cert, nil @@ -691,14 +529,12 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st // GetSSHHosts returns a list of valid host principals. func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) { if a.sshGetHostsFunc != nil { - return a.sshGetHostsFunc(cert) + hosts, err := a.sshGetHostsFunc(cert) + return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts") } hostnames, err := a.db.GetSSHHostPrincipals() if err != nil { - return nil, &apiError{ - err: errors.Wrap(err, "getSSHHosts"), - code: http.StatusInternalServerError, - } + return nil, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts") } hosts := make([]sshutil.Host, len(hostnames)) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index 9b403132..db5dc85d 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -5,8 +5,10 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/x509" "encoding/base64" "fmt" + "net/http" "reflect" "testing" "time" @@ -15,6 +17,8 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/sshutil" "github.com/smallstep/certificates/templates" "github.com/smallstep/cli/jose" "golang.org/x/crypto/ssh" @@ -498,8 +502,8 @@ func TestAuthority_CheckSSHHost(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) - a.db = &MockAuthDB{ - isSSHHost: func(_ string) (bool, error) { + a.db = &db.MockAuthDB{ + MIsSSHHost: func(_ string) (bool, error) { return tt.fields.exists, tt.fields.err }, } @@ -640,6 +644,9 @@ func TestAuthority_GetSSHBastion(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) return + } else if err != nil { + _, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) @@ -647,3 +654,266 @@ func TestAuthority_GetSSHBastion(t *testing.T) { }) } } + +func TestAuthority_GetSSHHosts(t *testing.T) { + a := testAuthority(t) + + type test struct { + getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error) + auth *Authority + cert *x509.Certificate + cmp func(got []sshutil.Host) + err error + code int + } + tests := map[string]func(t *testing.T) *test{ + "fail/getHostsFunc-fail": func(t *testing.T) *test { + return &test{ + getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) { + return nil, errors.New("force") + }, + cert: &x509.Certificate{}, + err: errors.New("getSSHHosts: force"), + code: http.StatusInternalServerError, + } + }, + "ok/getHostsFunc-defined": func(t *testing.T) *test { + hosts := []sshutil.Host{ + {HostID: "1", Hostname: "foo"}, + {HostID: "2", Hostname: "bar"}, + } + + return &test{ + getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) { + return hosts, nil + }, + cert: &x509.Certificate{}, + cmp: func(got []sshutil.Host) { + assert.Equals(t, got, hosts) + }, + } + }, + "fail/db-get-fail": func(t *testing.T) *test { + return &test{ + auth: testAuthority(t, WithDatabase(&db.MockAuthDB{ + MGetSSHHostPrincipals: func() ([]string, error) { + return nil, errors.New("force") + }, + })), + cert: &x509.Certificate{}, + err: errors.New("getSSHHosts: force"), + code: http.StatusInternalServerError, + } + }, + "ok": func(t *testing.T) *test { + return &test{ + auth: testAuthority(t, WithDatabase(&db.MockAuthDB{ + MGetSSHHostPrincipals: func() ([]string, error) { + return []string{"foo", "bar"}, nil + }, + })), + cert: &x509.Certificate{}, + cmp: func(got []sshutil.Host) { + assert.Equals(t, got, []sshutil.Host{ + {Hostname: "foo"}, + {Hostname: "bar"}, + }) + }, + } + }, + } + for name, genTestCase := range tests { + t.Run(name, func(t *testing.T) { + tc := genTestCase(t) + + auth := tc.auth + if auth == nil { + auth = a + } + auth.sshGetHostsFunc = tc.getHostsFunc + + hosts, err := auth.GetSSHHosts(tc.cert) + if err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + tc.cmp(hosts) + } + } + }) + } +} + +func TestAuthority_RekeySSH(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + pub, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + signer, err := ssh.NewSignerFromKey(signKey) + assert.FatalError(t, err) + + userOptions := sshTestModifier{ + CertType: ssh.UserCert, + } + + now := time.Now().UTC() + + a := testAuthority(t) + + type test struct { + auth *Authority + userSigner ssh.Signer + hostSigner ssh.Signer + cert *ssh.Certificate + key ssh.PublicKey + signOpts []provisioner.SignOption + cmpResult func(old, n *ssh.Certificate) + err error + code int + } + tests := map[string]func(t *testing.T) *test{ + "fail/opts-type": func(t *testing.T) *test { + return &test{ + userSigner: signer, + hostSigner: signer, + key: pub, + signOpts: []provisioner.SignOption{userOptions}, + err: errors.New("rekeySSH; invalid extra option type"), + code: http.StatusInternalServerError, + } + }, + "fail/old-cert-validAfter": func(t *testing.T) *test { + return &test{ + userSigner: signer, + hostSigner: signer, + cert: &ssh.Certificate{}, + key: pub, + signOpts: []provisioner.SignOption{}, + err: errors.New("rekeySSH; cannot rekey certificate without validity period"), + code: http.StatusBadRequest, + } + }, + "fail/old-cert-validBefore": func(t *testing.T) *test { + return &test{ + userSigner: signer, + hostSigner: signer, + cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, + key: pub, + signOpts: []provisioner.SignOption{}, + err: errors.New("rekeySSH; cannot rekey certificate without validity period"), + code: http.StatusBadRequest, + } + }, + "fail/old-cert-no-user-key": func(t *testing.T) *test { + return &test{ + userSigner: nil, + hostSigner: signer, + cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert}, + key: pub, + signOpts: []provisioner.SignOption{}, + err: errors.New("rekeySSH; user certificate signing is not enabled"), + code: http.StatusNotImplemented, + } + }, + "fail/old-cert-no-host-key": func(t *testing.T) *test { + return &test{ + userSigner: signer, + hostSigner: nil, + cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.HostCert}, + key: pub, + signOpts: []provisioner.SignOption{}, + err: errors.New("rekeySSH; host certificate signing is not enabled"), + code: http.StatusNotImplemented, + } + }, + "fail/unexpected-old-cert-type": func(t *testing.T) *test { + return &test{ + userSigner: signer, + hostSigner: signer, + cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, + key: pub, + signOpts: []provisioner.SignOption{}, + err: errors.New("rekeySSH; unexpected ssh certificate type: 0"), + code: http.StatusBadRequest, + } + }, + "fail/db-store": func(t *testing.T) *test { + return &test{ + auth: testAuthority(t, WithDatabase(&db.MockAuthDB{ + MStoreSSHCertificate: func(cert *ssh.Certificate) error { + return errors.New("force") + }, + })), + userSigner: signer, + hostSigner: nil, + cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert}, + key: pub, + signOpts: []provisioner.SignOption{}, + err: errors.New("rekeySSH; error storing certificate in db: force"), + code: http.StatusInternalServerError, + } + }, + "ok": func(t *testing.T) *test { + va1 := now.Add(-24 * time.Hour) + vb1 := now.Add(-23 * time.Hour) + return &test{ + userSigner: signer, + hostSigner: nil, + cert: &ssh.Certificate{ + ValidAfter: uint64(va1.Unix()), + ValidBefore: uint64(vb1.Unix()), + CertType: ssh.UserCert, + ValidPrincipals: []string{"foo", "bar"}, + KeyId: "foo", + }, + key: pub, + signOpts: []provisioner.SignOption{}, + cmpResult: func(old, n *ssh.Certificate) { + assert.Equals(t, n.CertType, old.CertType) + assert.Equals(t, n.ValidPrincipals, old.ValidPrincipals) + assert.Equals(t, n.KeyId, old.KeyId) + + assert.True(t, n.ValidAfter > uint64(now.Add(-5*time.Minute).Unix())) + assert.True(t, n.ValidAfter < uint64(now.Add(5*time.Minute).Unix())) + + l8r := now.Add(1 * time.Hour) + assert.True(t, n.ValidBefore > uint64(l8r.Add(-5*time.Minute).Unix())) + assert.True(t, n.ValidBefore < uint64(l8r.Add(5*time.Minute).Unix())) + }, + } + }, + } + for name, genTestCase := range tests { + t.Run(name, func(t *testing.T) { + tc := genTestCase(t) + + auth := tc.auth + if auth == nil { + auth = a + } + a.sshCAUserCertSignKey = tc.userSigner + a.sshCAHostCertSignKey = tc.hostSigner + + cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...) + if err != nil { + if assert.NotNil(t, tc.err) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + tc.cmpResult(tc.cert, cert) + } + } + }) + } +} diff --git a/authority/testdata/certs/ssh_host_ca_key.pub b/authority/testdata/certs/ssh_host_ca_key.pub new file mode 100644 index 00000000..aa5685da --- /dev/null +++ b/authority/testdata/certs/ssh_host_ca_key.pub @@ -0,0 +1 @@ +ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY= diff --git a/authority/testdata/certs/ssh_user_ca_key.pub b/authority/testdata/certs/ssh_user_ca_key.pub new file mode 100644 index 00000000..5909ce43 --- /dev/null +++ b/authority/testdata/certs/ssh_user_ca_key.pub @@ -0,0 +1 @@ +ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y= diff --git a/authority/testdata/secrets/ssh_host_ca_key b/authority/testdata/secrets/ssh_host_ca_key new file mode 100644 index 00000000..7a7e4c44 --- /dev/null +++ b/authority/testdata/secrets/ssh_host_ca_key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49 +AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+ +MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg== +-----END EC PRIVATE KEY----- diff --git a/authority/testdata/secrets/ssh_user_ca_key b/authority/testdata/secrets/ssh_user_ca_key new file mode 100644 index 00000000..92d35ec2 --- /dev/null +++ b/authority/testdata/secrets/ssh_user_ca_key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49 +AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h +YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg== +-----END EC PRIVATE KEY----- diff --git a/authority/tls.go b/authority/tls.go index eb7cb86a..9199c040 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/x509util" @@ -60,7 +61,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption { // Sign creates a signed certificate from a certificate signing request. func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { var ( - errContext = apiCtx{"csr": csr, "signOptions": signOpts} + opts = []errs.Option{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)} certValidators = []provisioner.CertificateValidator{} issIdentity = a.intermediateIdentity @@ -75,54 +76,52 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti certValidators = append(certValidators, k) case provisioner.CertificateRequestValidator: if err := k.Valid(csr); err != nil { - return nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext} + return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) } case provisioner.ProfileModifier: mods = append(mods, k.Option(signOpts)) default: - return nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k), - http.StatusInternalServerError, errContext} + return nil, errs.InternalServerError(errors.Errorf("authority.Sign; invalid extra option type %T", k), opts...) } } if err := csr.CheckSignature(); err != nil { - return nil, &apiError{errors.Wrap(err, "sign: invalid certificate request"), - http.StatusBadRequest, errContext} + return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...) } leaf, err := x509util.NewLeafProfileWithCSR(csr, issIdentity.Crt, issIdentity.Key, mods...) if err != nil { - return nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext} + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign", opts...) } for _, v := range certValidators { - if err := v.Valid(leaf.Subject()); err != nil { - return nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext} + if err := v.Valid(leaf.Subject(), signOpts); err != nil { + return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) } } crtBytes, err := leaf.CreateCertificate() if err != nil { - return nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"), - http.StatusInternalServerError, errContext} + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.Sign; error creating new leaf certificate", opts...) } serverCert, err := x509.ParseCertificate(crtBytes) if err != nil { - return nil, &apiError{errors.Wrap(err, "sign: error parsing new leaf certificate"), - http.StatusInternalServerError, errContext} + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.Sign; error parsing new leaf certificate", opts...) } caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw) if err != nil { - return nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"), - http.StatusInternalServerError, errContext} + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.Sign; error parsing intermediate certificate", opts...) } if err = a.db.StoreCertificate(serverCert); err != nil { if err != db.ErrNotImplemented { - return nil, &apiError{errors.Wrap(err, "sign: error storing certificate in db"), - http.StatusInternalServerError, errContext} + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.Sign; error storing certificate in db", opts...) } } @@ -132,9 +131,11 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti // Renew creates a new Certificate identical to the old certificate, except // with a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { + opts := []errs.Option{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())} + // Check step provisioner extensions if err := a.authorizeRenew(oldCert); err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew", opts...) } // Issuer @@ -161,16 +162,16 @@ func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error MaxPathLenZero: oldCert.MaxPathLenZero, OCSPServer: oldCert.OCSPServer, IssuingCertificateURL: oldCert.IssuingCertificateURL, + PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical, + PermittedEmailAddresses: oldCert.PermittedEmailAddresses, DNSNames: oldCert.DNSNames, EmailAddresses: oldCert.EmailAddresses, IPAddresses: oldCert.IPAddresses, URIs: oldCert.URIs, - PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical, PermittedDNSDomains: oldCert.PermittedDNSDomains, ExcludedDNSDomains: oldCert.ExcludedDNSDomains, PermittedIPRanges: oldCert.PermittedIPRanges, ExcludedIPRanges: oldCert.ExcludedIPRanges, - PermittedEmailAddresses: oldCert.PermittedEmailAddresses, ExcludedEmailAddresses: oldCert.ExcludedEmailAddresses, PermittedURIDomains: oldCert.PermittedURIDomains, ExcludedURIDomains: oldCert.ExcludedURIDomains, @@ -190,29 +191,28 @@ func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error leaf, err := x509util.NewLeafProfileWithTemplate(newCert, issIdentity.Crt, issIdentity.Key) if err != nil { - return nil, &apiError{err, http.StatusInternalServerError, apiCtx{}} + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew", opts...) } crtBytes, err := leaf.CreateCertificate() if err != nil { - return nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"), - http.StatusInternalServerError, apiCtx{}} + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.Renew; error renewing certificate from existing server certificate", opts...) } serverCert, err := x509.ParseCertificate(crtBytes) if err != nil { - return nil, &apiError{errors.Wrap(err, "error parsing new server certificate"), - http.StatusInternalServerError, apiCtx{}} + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.Renew; error parsing new server certificate", opts...) } caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw) if err != nil { - return nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"), - http.StatusInternalServerError, apiCtx{}} + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.Renew; error parsing intermediate certificate", opts...) } if err = a.db.StoreCertificate(serverCert); err != nil { if err != db.ErrNotImplemented { - return nil, &apiError{errors.Wrap(err, "error storing certificate in db"), - http.StatusInternalServerError, apiCtx{}} + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Renew; error storing certificate in db", opts...) } } @@ -236,26 +236,26 @@ type RevokeOptions struct { // being renewed. // // TODO: Add OCSP and CRL support. -func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error { - errContext := apiCtx{ - "serialNumber": opts.Serial, - "reasonCode": opts.ReasonCode, - "reason": opts.Reason, - "passiveOnly": opts.PassiveOnly, - "mTLS": opts.MTLS, - "context": string(provisioner.MethodFromContext(ctx)), +func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error { + opts := []errs.Option{ + errs.WithKeyVal("serialNumber", revokeOpts.Serial), + errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode), + errs.WithKeyVal("reason", revokeOpts.Reason), + errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly), + errs.WithKeyVal("MTLS", revokeOpts.MTLS), + errs.WithKeyVal("context", string(provisioner.MethodFromContext(ctx))), } - if opts.MTLS { - errContext["certificate"] = base64.StdEncoding.EncodeToString(opts.Crt.Raw) + if revokeOpts.MTLS { + opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw))) } else { - errContext["ott"] = opts.OTT + opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT)) } rci := &db.RevokedCertificateInfo{ - Serial: opts.Serial, - ReasonCode: opts.ReasonCode, - Reason: opts.Reason, - MTLS: opts.MTLS, + Serial: revokeOpts.Serial, + ReasonCode: revokeOpts.ReasonCode, + Reason: revokeOpts.Reason, + MTLS: revokeOpts.MTLS, RevokedAt: time.Now().UTC(), } @@ -264,48 +264,43 @@ func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error { err error ) // If not mTLS then get the TokenID of the token. - if !opts.MTLS { - // Validate payload - token, err := jose.ParseSigned(opts.OTT) + if !revokeOpts.MTLS { + token, err := jose.ParseSigned(revokeOpts.OTT) if err != nil { - return &apiError{errors.Wrapf(err, "revoke: error parsing token"), - http.StatusUnauthorized, errContext} + return errs.Wrap(http.StatusUnauthorized, err, + "authority.Revoke; error parsing token", opts...) } - // Get claims w/out verification. We should have already verified this token - // earlier with a call to authorizeSSHRevoke. + // Get claims w/out verification. var claims Claims if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { - return &apiError{errors.Wrap(err, "revoke"), http.StatusUnauthorized, errContext} + return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke", opts...) } // This method will also validate the audiences for JWK provisioners. var ok bool p, ok = a.provisioners.LoadByToken(token, &claims.Claims) if !ok { - return &apiError{ - errors.Errorf("revoke: provisioner not found"), - http.StatusInternalServerError, errContext} + return errs.InternalServerError(errors.Errorf("authority.Revoke; provisioner not found"), opts...) } - rci.TokenID, err = p.GetTokenID(opts.OTT) + rci.TokenID, err = p.GetTokenID(revokeOpts.OTT) if err != nil { - return &apiError{errors.Wrap(err, "revoke: could not get ID for token"), - http.StatusInternalServerError, errContext} + return errs.Wrap(http.StatusInternalServerError, err, + "authority.Revoke; could not get ID for token") } - errContext["tokenID"] = rci.TokenID + opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID)) } else { // Load the Certificate provisioner if one exists. - p, err = a.LoadProvisionerByCertificate(opts.Crt) + p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt) if err != nil { - return &apiError{ - errors.Wrap(err, "revoke: unable to load certificate provisioner"), - http.StatusUnauthorized, errContext} + return errs.Wrap(http.StatusUnauthorized, err, + "authority.Revoke: unable to load certificate provisioner", opts...) } } rci.ProvisionerID = p.GetID() - errContext["provisionerID"] = rci.ProvisionerID + opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID)) - if provisioner.MethodFromContext(ctx) == provisioner.RevokeSSHMethod { + if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod { err = a.db.RevokeSSH(rci) } else { // default to revoke x509 err = a.db.Revoke(rci) @@ -314,13 +309,12 @@ func (a *Authority) Revoke(ctx context.Context, opts *RevokeOptions) error { case nil: return nil case db.ErrNotImplemented: - return &apiError{errors.New("revoke: no persistence layer configured"), - http.StatusNotImplemented, errContext} + return errs.NotImplemented(errors.New("authority.Revoke; no persistence layer configured"), opts...) case db.ErrAlreadyExists: - return &apiError{errors.Errorf("revoke: certificate with serial number %s has already been revoked", rci.Serial), - http.StatusBadRequest, errContext} + return errs.BadRequest(errors.Errorf("authority.Revoke; certificate with serial "+ + "number %s has already been revoked", rci.Serial), opts...) default: - return &apiError{err, http.StatusInternalServerError, errContext} + return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) } } @@ -330,17 +324,17 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) { a.intermediateIdentity.Crt, a.intermediateIdentity.Key, x509util.WithHosts(strings.Join(a.config.DNSNames, ","))) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate") } crtBytes, err := profile.CreateCertificate() if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate") } keyPEM, err := pemutil.Serialize(profile.SubjectPrivateKey()) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate") } crtPEM := pem.EncodeToMemory(&pem.Block{ @@ -352,19 +346,21 @@ func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) { // to a tls.Certificate. intermediatePEM, err := pemutil.Serialize(a.intermediateIdentity.Crt) if err != nil { - return nil, err + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate") } tlsCrt, err := tls.X509KeyPair(append(crtPEM, pem.EncodeToMemory(intermediatePEM)...), pem.EncodeToMemory(keyPEM)) if err != nil { - return nil, errors.Wrap(err, "error creating tls certificate") + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.GetTLSCertificate; error creating tls certificate") } // Get the 'leaf' certificate and set the attribute accordingly. leaf, err := x509.ParseCertificate(tlsCrt.Certificate[0]) if err != nil { - return nil, errors.Wrap(err, "error parsing tls certificate") + return nil, errs.Wrap(http.StatusInternalServerError, err, + "authority.GetTLSCertificate; error parsing tls certificate") } tlsCrt.Leaf = leaf diff --git a/authority/tls_test.go b/authority/tls_test.go index c5c7f8c1..3fbd21bf 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -7,7 +7,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/asn1" - "encoding/base64" "encoding/pem" "fmt" "net/http" @@ -19,6 +18,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/tlsutil" @@ -77,7 +77,7 @@ func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateReques return csr } -func TestSign(t *testing.T) { +func TestAuthority_Sign(t *testing.T) { pub, priv, err := keys.GenerateDefaultKeyPair() assert.FatalError(t, err) @@ -102,7 +102,7 @@ func TestSign(t *testing.T) { p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK) key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) assert.FatalError(t, err) - token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key) + token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key) assert.FatalError(t, err) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) extraOpts, err := a.Authorize(ctx, token) @@ -113,7 +113,8 @@ func TestSign(t *testing.T) { csr *x509.CertificateRequest signOpts provisioner.Options extraOpts []provisioner.SignOption - err *apiError + err error + code int } tests := map[string]func(*testing.T) *signTest{ "fail invalid signature": func(t *testing.T) *signTest { @@ -124,10 +125,8 @@ func TestSign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: &apiError{errors.New("sign: invalid certificate request"), - http.StatusBadRequest, - apiCtx{"csr": csr, "signOptions": signOpts}, - }, + err: errors.New("authority.Sign; invalid certificate request"), + code: http.StatusBadRequest, } }, "fail invalid extra option": func(t *testing.T) *signTest { @@ -138,10 +137,8 @@ func TestSign(t *testing.T) { csr: csr, extraOpts: append(extraOpts, "42"), signOpts: signOpts, - err: &apiError{errors.New("sign: invalid extra option type string"), - http.StatusInternalServerError, - apiCtx{"csr": csr, "signOptions": signOpts}, - }, + err: errors.New("authority.Sign; invalid extra option type string"), + code: http.StatusInternalServerError, } }, "fail merge default ASN1DN": func(t *testing.T) *signTest { @@ -153,10 +150,8 @@ func TestSign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"), - http.StatusInternalServerError, - apiCtx{"csr": csr, "signOptions": signOpts}, - }, + err: errors.New("authority.Sign: default ASN1DN template cannot be nil"), + code: http.StatusInternalServerError, } }, "fail create cert": func(t *testing.T) *signTest { @@ -168,10 +163,8 @@ func TestSign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: &apiError{errors.New("sign: error creating new leaf certificate"), - http.StatusInternalServerError, - apiCtx{"csr": csr, "signOptions": signOpts}, - }, + err: errors.New("authority.Sign; error creating new leaf certificate"), + code: http.StatusInternalServerError, } }, "fail provisioner duration claim": func(t *testing.T) *signTest { @@ -185,10 +178,8 @@ func TestSign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: _signOpts, - err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"), - http.StatusUnauthorized, - apiCtx{"csr": csr, "signOptions": _signOpts}, - }, + err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"), + code: http.StatusUnauthorized, } }, "fail validate sans when adding common name not in claims": func(t *testing.T) *signTest { @@ -200,10 +191,8 @@ func TestSign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), - http.StatusUnauthorized, - apiCtx{"csr": csr, "signOptions": signOpts}, - }, + err: errors.New("authority.Sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), + code: http.StatusUnauthorized, } }, "fail rsa key too short": func(t *testing.T) *signTest { @@ -228,20 +217,16 @@ ZYtQ9Ot36qc= csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"), - http.StatusUnauthorized, - apiCtx{"csr": csr, "signOptions": signOpts}, - }, + err: errors.New("authority.Sign: rsa key in CSR must be at least 2048 bits (256 bytes)"), + code: http.StatusUnauthorized, } }, "fail store cert in db": func(t *testing.T) *signTest { csr := getCSR(t, priv) _a := testAuthority(t) - _a.db = &MockAuthDB{ - storeCertificate: func(crt *x509.Certificate) error { - return &apiError{errors.New("force"), - http.StatusInternalServerError, - apiCtx{"csr": csr, "signOptions": signOpts}} + _a.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + return errors.New("force") }, } return &signTest{ @@ -249,17 +234,15 @@ ZYtQ9Ot36qc= csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: &apiError{errors.New("sign: error storing certificate in db: force"), - http.StatusInternalServerError, - apiCtx{"csr": csr, "signOptions": signOpts}, - }, + err: errors.New("authority.Sign; error storing certificate in db: force"), + code: http.StatusInternalServerError, } }, "ok": func(t *testing.T) *signTest { csr := getCSR(t, priv) _a := testAuthority(t) - _a.db = &MockAuthDB{ - storeCertificate: func(crt *x509.Certificate) error { + _a.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { assert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, @@ -279,15 +262,17 @@ ZYtQ9Ot36qc= certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...) if err != nil { - if assert.NotNil(t, tc.err) { - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) - } + if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { + assert.Nil(t, certChain) + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + + ctxErr, ok := err.(*errs.Error) + assert.Fatal(t, ok, "error is not of type *errs.Error") + assert.Equals(t, ctxErr.Details["csr"], tc.csr) + assert.Equals(t, ctxErr.Details["signOptions"], tc.signOpts) } } else { leaf := certChain[0] @@ -346,7 +331,7 @@ ZYtQ9Ot36qc= } } -func TestRenew(t *testing.T) { +func TestAuthority_Renew(t *testing.T) { pub, _, err := keys.GenerateDefaultKeyPair() assert.FatalError(t, err) @@ -375,9 +360,9 @@ func TestRenew(t *testing.T) { x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"), withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID)) assert.FatalError(t, err) - crtBytes, err := leaf.CreateCertificate() + certBytes, err := leaf.CreateCertificate() assert.FatalError(t, err) - crt, err := x509.ParseCertificate(crtBytes) + cert, err := x509.ParseCertificate(certBytes) assert.FatalError(t, err) leafNoRenew, err := x509util.NewLeafProfile("norenew", a.intermediateIdentity.Crt, @@ -388,15 +373,16 @@ func TestRenew(t *testing.T) { withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID), ) assert.FatalError(t, err) - crtBytesNoRenew, err := leafNoRenew.CreateCertificate() + certBytesNoRenew, err := leafNoRenew.CreateCertificate() assert.FatalError(t, err) - crtNoRenew, err := x509.ParseCertificate(crtBytesNoRenew) + certNoRenew, err := x509.ParseCertificate(certBytesNoRenew) assert.FatalError(t, err) type renewTest struct { auth *Authority - crt *x509.Certificate - err *apiError + cert *x509.Certificate + err error + code int } tests := map[string]func() (*renewTest, error){ "fail-create-cert": func() (*renewTest, error) { @@ -404,25 +390,22 @@ func TestRenew(t *testing.T) { _a.intermediateIdentity.Key = nil return &renewTest{ auth: _a, - crt: crt, - err: &apiError{errors.New("error renewing certificate from existing server certificate"), - http.StatusInternalServerError, apiCtx{}}, + cert: cert, + err: errors.New("authority.Renew; error renewing certificate from existing server certificate"), + code: http.StatusInternalServerError, }, nil }, "fail-unauthorized": func() (*renewTest, error) { - ctx := map[string]interface{}{ - "serialNumber": crtNoRenew.SerialNumber.String(), - } return &renewTest{ - crt: crtNoRenew, - err: &apiError{errors.New("renew: renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), - http.StatusUnauthorized, ctx}, + cert: certNoRenew, + err: errors.New("authority.Renew: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), + code: http.StatusUnauthorized, }, nil }, "success": func() (*renewTest, error) { return &renewTest{ auth: a, - crt: crt, + cert: cert, }, nil }, "success-new-intermediate": func() (*renewTest, error) { @@ -430,23 +413,23 @@ func TestRenew(t *testing.T) { assert.FatalError(t, err) newRootBytes, err := newRootProfile.CreateCertificate() assert.FatalError(t, err) - newRootCrt, err := x509.ParseCertificate(newRootBytes) + newRootCert, err := x509.ParseCertificate(newRootBytes) assert.FatalError(t, err) newIntermediateProfile, err := x509util.NewIntermediateProfile("new-intermediate", - newRootCrt, newRootProfile.SubjectPrivateKey()) + newRootCert, newRootProfile.SubjectPrivateKey()) assert.FatalError(t, err) newIntermediateBytes, err := newIntermediateProfile.CreateCertificate() assert.FatalError(t, err) - newIntermediateCrt, err := x509.ParseCertificate(newIntermediateBytes) + newIntermediateCert, err := x509.ParseCertificate(newIntermediateBytes) assert.FatalError(t, err) _a := testAuthority(t) _a.intermediateIdentity.Key = newIntermediateProfile.SubjectPrivateKey() - _a.intermediateIdentity.Crt = newIntermediateCrt + _a.intermediateIdentity.Crt = newIntermediateCert return &renewTest{ auth: _a, - crt: crt, + cert: cert, }, nil }, } @@ -458,32 +441,33 @@ func TestRenew(t *testing.T) { var certChain []*x509.Certificate if tc.auth != nil { - certChain, err = tc.auth.Renew(tc.crt) + certChain, err = tc.auth.Renew(tc.cert) } else { - certChain, err = a.Renew(tc.crt) + certChain, err = a.Renew(tc.cert) } if err != nil { - if assert.NotNil(t, tc.err) { - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) - } + if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { + assert.Nil(t, certChain) + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + + ctxErr, ok := err.(*errs.Error) + assert.Fatal(t, ok, "error is not of type *errs.Error") + assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String()) } } else { leaf := certChain[0] intermediate := certChain[1] if assert.Nil(t, tc.err) { - assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.crt.NotAfter.Sub(crt.NotBefore)) + assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore)) - assert.True(t, leaf.NotBefore.After(now.Add(-time.Minute))) + assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute))) assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute))) expiry := now.Add(time.Minute * 7) - assert.True(t, leaf.NotAfter.After(expiry.Add(-time.Minute))) + assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) tmplt := a.config.AuthorityConfig.Template @@ -513,7 +497,7 @@ func TestRenew(t *testing.T) { if a.intermediateIdentity.Crt.SerialNumber == tc.auth.intermediateIdentity.Crt.SerialNumber { assert.Equals(t, leaf.AuthorityKeyId, a.intermediateIdentity.Crt.SubjectKeyId) // Compare extensions: they can be in a different order - for _, ext1 := range tc.crt.Extensions { + for _, ext1 := range tc.cert.Extensions { found := false for _, ext2 := range leaf.Extensions { if reflect.DeepEqual(ext1, ext2) { @@ -529,7 +513,7 @@ func TestRenew(t *testing.T) { // We did change the intermediate before renewing. assert.Equals(t, leaf.AuthorityKeyId, tc.auth.intermediateIdentity.Crt.SubjectKeyId) // Compare extensions: they can be in a different order - for _, ext1 := range tc.crt.Extensions { + for _, ext1 := range tc.cert.Extensions { // The authority key id extension should be different b/c the intermediates are different. if ext1.Id.Equal(oidAuthorityKeyIdentifier) { for _, ext2 := range leaf.Extensions { @@ -560,7 +544,7 @@ func TestRenew(t *testing.T) { } } -func TestGetTLSOptions(t *testing.T) { +func TestAuthority_GetTLSOptions(t *testing.T) { type renewTest struct { auth *Authority opts *tlsutil.TLSOptions @@ -596,21 +580,12 @@ func TestGetTLSOptions(t *testing.T) { } } -func TestRevoke(t *testing.T) { +func TestAuthority_Revoke(t *testing.T) { reasonCode := 2 reason := "bob was let go" validIssuer := "step-cli" - validAudience := []string{"https://test.ca.smallstep.com/revoke"} + validAudience := testAudiences.Revoke now := time.Now().UTC() - getCtx := func() map[string]interface{} { - return apiCtx{ - "serialNumber": "sn", - "reasonCode": reasonCode, - "reason": reason, - "mTLS": false, - "passiveOnly": false, - } - } jwk, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) assert.FatalError(t, err) @@ -619,30 +594,30 @@ func TestRevoke(t *testing.T) { (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) assert.FatalError(t, err) + a := testAuthority(t) + type test struct { - a *Authority - opts *RevokeOptions - err *apiError + auth *Authority + opts *RevokeOptions + err error + code int + checkErrDetails func(err *errs.Error) } tests := map[string]func() test{ - "error/token/authorizeRevoke error": func() test { - a := testAuthority(t) - ctx := getCtx() - ctx["ott"] = "foo" + "fail/token/authorizeRevoke error": func() test { return test{ - a: a, + auth: a, opts: &RevokeOptions{ OTT: "foo", Serial: "sn", ReasonCode: reasonCode, Reason: reason, }, - err: &apiError{errors.New("revoke: authorizeRevoke: authorizeToken: error parsing token"), - http.StatusUnauthorized, ctx}, + err: errors.New("authority.Revoke; error parsing token"), + code: http.StatusUnauthorized, } }, - "error/nil-db": func() test { - a := testAuthority(t) + "fail/nil-db": func() test { cl := jwt.Claims{ Subject: "sn", Issuer: validIssuer, @@ -654,30 +629,30 @@ func TestRevoke(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) - ctx := getCtx() - ctx["ott"] = raw - ctx["tokenID"] = "44" - ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc" return test{ - a: a, + auth: a, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, - err: &apiError{errors.New("revoke: no persistence layer configured"), - http.StatusNotImplemented, ctx}, + err: errors.New("authority.Revoke; no persistence layer configured"), + code: http.StatusNotImplemented, + checkErrDetails: func(err *errs.Error) { + assert.Equals(t, err.Details["token"], raw) + assert.Equals(t, err.Details["tokenID"], "44") + assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") + }, } }, - "error/db-revoke": func() test { - a := testAuthority(t) - a.db = &MockAuthDB{ - useToken: func(id, tok string) (bool, error) { + "fail/db-revoke": func() test { + _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { return true, nil }, - err: errors.New("force"), - } + Err: errors.New("force"), + })) cl := jwt.Claims{ Subject: "sn", @@ -690,30 +665,30 @@ func TestRevoke(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) - ctx := getCtx() - ctx["ott"] = raw - ctx["tokenID"] = "44" - ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc" return test{ - a: a, + auth: _a, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, - err: &apiError{errors.New("force"), - http.StatusInternalServerError, ctx}, + err: errors.New("authority.Revoke: force"), + code: http.StatusInternalServerError, + checkErrDetails: func(err *errs.Error) { + assert.Equals(t, err.Details["token"], raw) + assert.Equals(t, err.Details["tokenID"], "44") + assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") + }, } }, - "error/already-revoked": func() test { - a := testAuthority(t) - a.db = &MockAuthDB{ - useToken: func(id, tok string) (bool, error) { + "fail/already-revoked": func() test { + _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { return true, nil }, - err: db.ErrAlreadyExists, - } + Err: db.ErrAlreadyExists, + })) cl := jwt.Claims{ Subject: "sn", @@ -726,29 +701,29 @@ func TestRevoke(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) - ctx := getCtx() - ctx["ott"] = raw - ctx["tokenID"] = "44" - ctx["provisionerID"] = "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc" return test{ - a: a, + auth: _a, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, - err: &apiError{errors.New("revoke: certificate with serial number sn has already been revoked"), - http.StatusBadRequest, ctx}, + err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"), + code: http.StatusBadRequest, + checkErrDetails: func(err *errs.Error) { + assert.Equals(t, err.Details["token"], raw) + assert.Equals(t, err.Details["tokenID"], "44") + assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") + }, } }, "ok/token": func() test { - a := testAuthority(t) - a.db = &MockAuthDB{ - useToken: func(id, tok string) (bool, error) { + _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { return true, nil }, - } + })) cl := jwt.Claims{ Subject: "sn", @@ -761,7 +736,7 @@ func TestRevoke(t *testing.T) { raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return test{ - a: a, + auth: _a, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, @@ -770,39 +745,14 @@ func TestRevoke(t *testing.T) { }, } }, - "error/mTLS/authorizeRevoke": func() test { - a := testAuthority(t) - a.db = &MockAuthDB{} - - crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") - assert.FatalError(t, err) - - ctx := getCtx() - ctx["certificate"] = base64.StdEncoding.EncodeToString(crt.Raw) - ctx["mTLS"] = true - - return test{ - a: a, - opts: &RevokeOptions{ - Crt: crt, - Serial: "sn", - ReasonCode: reasonCode, - Reason: reason, - MTLS: true, - }, - err: &apiError{errors.New("revoke: authorizeRevoke: serial number in certificate different than body"), - http.StatusUnauthorized, ctx}, - } - }, "ok/mTLS": func() test { - a := testAuthority(t) - a.db = &MockAuthDB{} + _a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") assert.FatalError(t, err) return test{ - a: a, + auth: _a, opts: &RevokeOptions{ Crt: crt, Serial: "102012593071130646873265215610956555026", @@ -816,15 +766,24 @@ func TestRevoke(t *testing.T) { for name, f := range tests { tc := f() t.Run(name, func(t *testing.T) { - if err := tc.a.Revoke(context.TODO(), tc.opts); err != nil { - if assert.NotNil(t, tc.err) { - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) + if err := tc.auth.Revoke(ctx, tc.opts); err != nil { + if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) + + ctxErr, ok := err.(*errs.Error) + assert.Fatal(t, ok, "error is not of type *errs.Error") + assert.Equals(t, ctxErr.Details["serialNumber"], tc.opts.Serial) + assert.Equals(t, ctxErr.Details["reasonCode"], tc.opts.ReasonCode) + assert.Equals(t, ctxErr.Details["reason"], tc.opts.Reason) + assert.Equals(t, ctxErr.Details["MTLS"], tc.opts.MTLS) + assert.Equals(t, ctxErr.Details["context"], string(provisioner.RevokeMethod)) + + if tc.checkErrDetails != nil { + tc.checkErrDetails(ctxErr) } } } else { diff --git a/ca/ca_test.go b/ca/ca_test.go index ef00132c..4a04756d 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -22,6 +22,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/randutil" @@ -102,7 +103,7 @@ func TestCASign(t *testing.T) { ca: ca, body: "invalid json", status: http.StatusBadRequest, - errMsg: "Bad Request", + errMsg: errs.BadRequestDefaultMsg, } }, "fail invalid-csr-sig": func(t *testing.T) *signTest { @@ -140,7 +141,7 @@ ZEp7knvU2psWRw== ca: ca, body: string(body), status: http.StatusBadRequest, - errMsg: "Bad Request", + errMsg: errs.BadRequestDefaultMsg, } }, "fail unauthorized-ott": func(t *testing.T) *signTest { @@ -155,7 +156,7 @@ ZEp7knvU2psWRw== ca: ca, body: string(body), status: http.StatusUnauthorized, - errMsg: "Unauthorized", + errMsg: errs.UnauthorizedDefaultMsg, } }, "fail commonname-claim": func(t *testing.T) *signTest { @@ -188,7 +189,7 @@ ZEp7knvU2psWRw== ca: ca, body: string(body), status: http.StatusUnauthorized, - errMsg: "Unauthorized", + errMsg: errs.UnauthorizedDefaultMsg, } }, "ok": func(t *testing.T) *signTest { @@ -392,7 +393,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { ca: ca, kid: "foo", status: http.StatusNotFound, - errMsg: "Not Found", + errMsg: errs.NotFoundDefaultMsg, } }, "ok": func(t *testing.T) *ekt { @@ -455,7 +456,7 @@ func TestCARoot(t *testing.T) { ca: ca, sha: "foo", status: http.StatusNotFound, - errMsg: "Not Found", + errMsg: errs.NotFoundDefaultMsg, } }, "success": func(t *testing.T) *rootTest { @@ -575,7 +576,7 @@ func TestCARenew(t *testing.T) { ca: ca, tlsConnState: nil, status: http.StatusBadRequest, - errMsg: "Bad Request", + errMsg: errs.BadRequestDefaultMsg, } }, "request-missing-peer-certificate": func(t *testing.T) *renewTest { @@ -583,7 +584,7 @@ func TestCARenew(t *testing.T) { ca: ca, tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}}, status: http.StatusBadRequest, - errMsg: "Bad Request", + errMsg: errs.BadRequestDefaultMsg, } }, "success": func(t *testing.T) *renewTest { diff --git a/ca/client.go b/ca/client.go index 051bac5b..e6fdab92 100644 --- a/ca/client.go +++ b/ca/client.go @@ -486,7 +486,7 @@ func (c *Client) Version() (*api.VersionResponse, error) { retry: resp, err := c.client.Get(u.String()) if err != nil { - return nil, errors.Wrapf(err, "client GET %s failed", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; client GET %s failed", u) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { @@ -497,7 +497,7 @@ retry: } var version api.VersionResponse if err := readJSON(resp.Body, &version); err != nil { - return nil, errors.Wrapf(err, "error reading %s", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; error reading %s", u) } return &version, nil } @@ -510,7 +510,7 @@ func (c *Client) Health() (*api.HealthResponse, error) { retry: resp, err := c.client.Get(u.String()) if err != nil { - return nil, errors.Wrapf(err, "client GET %s failed", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; client GET %s failed", u) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { @@ -521,7 +521,7 @@ retry: } var health api.HealthResponse if err := readJSON(resp.Body, &health); err != nil { - return nil, errors.Wrapf(err, "error reading %s", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; error reading %s", u) } return &health, nil } @@ -537,7 +537,7 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { retry: resp, err := newInsecureClient().Get(u.String()) if err != nil { - return nil, errors.Wrapf(err, "client GET %s failed", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; client GET %s failed", u) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { @@ -548,12 +548,12 @@ retry: } var root api.RootResponse if err := readJSON(resp.Body, &root); err != nil { - return nil, errors.Wrapf(err, "error reading %s", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; error reading %s", u) } // verify the sha256 sum := sha256.Sum256(root.RootPEM.Raw) if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) { - return nil, errors.New("root certificate SHA256 fingerprint do not match") + return nil, errs.BadRequest(errors.New("client.Root; root certificate SHA256 fingerprint do not match")) } return &root, nil } @@ -564,13 +564,13 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { - return nil, errors.Wrap(err, "error marshaling request") + return nil, errs.Wrap(http.StatusInternalServerError, err, "client.Sign; error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"}) retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { - return nil, errors.Wrapf(err, "client POST %s failed", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; client POST %s failed", u) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { @@ -581,7 +581,7 @@ retry: } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { - return nil, errors.Wrapf(err, "error reading %s", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; error reading %s", u) } // Add tls.ConnectionState: // We'll extract the root certificate from the verified chains @@ -598,7 +598,7 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { retry: resp, err := client.Post(u.String(), "application/json", http.NoBody) if err != nil { - return nil, errors.Wrapf(err, "client POST %s failed", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; client POST %s failed", u) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { @@ -609,7 +609,7 @@ retry: } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { - return nil, errors.Wrapf(err, "error reading %s", u) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; error reading %s", u) } return &sign, nil } @@ -1008,13 +1008,13 @@ func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse var retried bool body, err := json.Marshal(req) if err != nil { - return nil, errors.Wrap(err, "error marshaling request") + return nil, errors.Wrap(err, "client.SSHBastion; error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"}) retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { - return nil, errors.Wrapf(err, "client POST %s failed", u) + return nil, errors.Wrapf(err, "client.SSHBastion; client POST %s failed", u) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { @@ -1025,7 +1025,7 @@ retry: } var bastion api.SSHBastionResponse if err := readJSON(resp.Body, &bastion); err != nil { - return nil, errors.Wrapf(err, "error reading %s", u) + return nil, errors.Wrapf(err, "client.SSHBastion; error reading %s", u) } return &bastion, nil } diff --git a/ca/client_test.go b/ca/client_test.go index c2e0063e..5b74f5cb 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -16,12 +16,12 @@ import ( "testing" "time" - "github.com/smallstep/certificates/errs" - + "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/x509util" "golang.org/x/crypto/ssh" ) @@ -154,18 +154,17 @@ func equalJSON(t *testing.T, a interface{}, b interface{}) bool { func TestClient_Version(t *testing.T) { ok := &api.VersionResponse{Version: "test"} - internal := errs.InternalServerError(fmt.Errorf("Internal Server Error")) - notFound := errs.NotFound(fmt.Errorf("Not Found")) tests := []struct { name string response interface{} responseCode int wantErr bool + expectedErr error }{ - {"ok", ok, 200, false}, - {"500", internal, 500, true}, - {"404", notFound, 404, true}, + {"ok", ok, 200, false, nil}, + {"500", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, + {"404", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -185,7 +184,6 @@ func TestClient_Version(t *testing.T) { got, err := c.Version() if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr) return } @@ -195,9 +193,7 @@ func TestClient_Version(t *testing.T) { if got != nil { t.Errorf("Client.Version() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Version() error = %v, want %v", err, tt.response) - } + assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Version() = %v, want %v", got, tt.response) @@ -209,16 +205,16 @@ func TestClient_Version(t *testing.T) { func TestClient_Health(t *testing.T) { ok := &api.HealthResponse{Status: "ok"} - nok := errs.InternalServerError(fmt.Errorf("Internal Server Error")) tests := []struct { name string response interface{} responseCode int wantErr bool + expectedErr error }{ - {"ok", ok, 200, false}, - {"not ok", nok, 500, true}, + {"ok", ok, 200, false, nil}, + {"not ok", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -248,9 +244,7 @@ func TestClient_Health(t *testing.T) { if got != nil { t.Errorf("Client.Health() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Health() error = %v, want %v", err, tt.response) - } + assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Health() = %v, want %v", got, tt.response) @@ -264,7 +258,6 @@ func TestClient_Root(t *testing.T) { ok := &api.RootResponse{ RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, } - notFound := errs.NotFound(fmt.Errorf("Not Found")) tests := []struct { name string @@ -272,9 +265,10 @@ func TestClient_Root(t *testing.T) { response interface{} responseCode int wantErr bool + expectedErr error }{ - {"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false}, - {"not found", "invalid", notFound, 404, true}, + {"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil}, + {"not found", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -307,9 +301,7 @@ func TestClient_Root(t *testing.T) { if got != nil { t.Errorf("Client.Root() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Root() error = %v, want %v", err, tt.response) - } + assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Root() = %v, want %v", got, tt.response) @@ -334,8 +326,6 @@ func TestClient_Sign(t *testing.T) { NotBefore: api.NewTimeDuration(time.Now()), NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), } - unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized")) - badRequest := errs.BadRequest(fmt.Errorf("Bad Request")) tests := []struct { name string @@ -343,11 +333,12 @@ func TestClient_Sign(t *testing.T) { response interface{} responseCode int wantErr bool + expectedErr error }{ - {"ok", request, ok, 200, false}, - {"unauthorized", request, unauthorized, 401, true}, - {"empty request", &api.SignRequest{}, badRequest, 403, true}, - {"nil request", nil, badRequest, 403, true}, + {"ok", request, ok, 200, false, nil}, + {"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", &api.SignRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -364,7 +355,9 @@ func TestClient_Sign(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.SignRequest) if err := api.ReadJSON(req.Body, body); err != nil { - api.WriteError(w, badRequest) + e, ok := tt.response.(error) + assert.Fatal(t, ok, "response expected to be error type") + api.WriteError(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -390,9 +383,7 @@ func TestClient_Sign(t *testing.T) { if got != nil { t.Errorf("Client.Sign() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Sign() error = %v, want %v", err, tt.response) - } + assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Sign() = %v, want %v", got, tt.response) @@ -409,19 +400,17 @@ func TestClient_Revoke(t *testing.T) { OTT: "the-ott", ReasonCode: 4, } - unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized")) - badRequest := errs.BadRequest(fmt.Errorf("Bad Request")) - tests := []struct { name string request *api.RevokeRequest response interface{} responseCode int wantErr bool + expectedErr error }{ - {"ok", request, ok, 200, false}, - {"unauthorized", request, unauthorized, 401, true}, - {"nil request", nil, badRequest, 403, true}, + {"ok", request, ok, 200, false, nil}, + {"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -438,7 +427,9 @@ func TestClient_Revoke(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.RevokeRequest) if err := api.ReadJSON(req.Body, body); err != nil { - api.WriteError(w, badRequest) + e, ok := tt.response.(error) + assert.Fatal(t, ok, "response expected to be error type") + api.WriteError(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -464,9 +455,7 @@ func TestClient_Revoke(t *testing.T) { if got != nil { t.Errorf("Client.Revoke() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Revoke() error = %v, want %v", err, tt.response) - } + assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) @@ -485,19 +474,18 @@ func TestClient_Renew(t *testing.T) { {Certificate: parseCertificate(rootPEM)}, }, } - unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized")) - badRequest := errs.BadRequest(fmt.Errorf("Bad Request")) tests := []struct { name string response interface{} responseCode int wantErr bool + err error }{ - {"ok", ok, 200, false}, - {"unauthorized", unauthorized, 401, true}, - {"empty request", badRequest, 403, true}, - {"nil request", badRequest, 403, true}, + {"ok", ok, 200, false, nil}, + {"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -527,9 +515,11 @@ func TestClient_Renew(t *testing.T) { if got != nil { t.Errorf("Client.Renew() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Renew() error = %v, want %v", err, tt.response) - } + + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -589,9 +579,7 @@ func TestClient_Provisioners(t *testing.T) { if got != nil { t.Errorf("Client.Provisioners() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Provisioners() error = %v, want %v", err, tt.response) - } + assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) @@ -605,7 +593,6 @@ func TestClient_ProvisionerKey(t *testing.T) { ok := &api.ProvisionerKeyResponse{ Key: "an encrypted key", } - notFound := errs.NotFound(fmt.Errorf("Not Found")) tests := []struct { name string @@ -613,9 +600,10 @@ func TestClient_ProvisionerKey(t *testing.T) { response interface{} responseCode int wantErr bool + err error }{ - {"ok", "kid", ok, 200, false}, - {"fail", "invalid", notFound, 500, true}, + {"ok", "kid", ok, 200, false, nil}, + {"fail", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -648,9 +636,11 @@ func TestClient_ProvisionerKey(t *testing.T) { if got != nil { t.Errorf("Client.ProvisionerKey() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.ProvisionerKey() error = %v, want %v", err, tt.response) - } + + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) @@ -666,19 +656,17 @@ func TestClient_Roots(t *testing.T) { {Certificate: parseCertificate(rootPEM)}, }, } - unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized")) - badRequest := errs.BadRequest(fmt.Errorf("Bad Request")) tests := []struct { name string response interface{} responseCode int wantErr bool + err error }{ - {"ok", ok, 200, false}, - {"unauthorized", unauthorized, 401, true}, - {"empty request", badRequest, 403, true}, - {"nil request", badRequest, 403, true}, + {"ok", ok, 200, false, nil}, + {"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"bad-request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -708,9 +696,10 @@ func TestClient_Roots(t *testing.T) { if got != nil { t.Errorf("Client.Roots() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Roots() error = %v, want %v", err, tt.response) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Roots() = %v, want %v", got, tt.response) @@ -726,19 +715,16 @@ func TestClient_Federation(t *testing.T) { {Certificate: parseCertificate(rootPEM)}, }, } - unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized")) - badRequest := errs.BadRequest(fmt.Errorf("Bad Request")) tests := []struct { name string response interface{} responseCode int wantErr bool + err error }{ - {"ok", ok, 200, false}, - {"unauthorized", unauthorized, 401, true}, - {"empty request", badRequest, 403, true}, - {"nil request", badRequest, 403, true}, + {"ok", ok, 200, false, nil}, + {"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -768,9 +754,10 @@ func TestClient_Federation(t *testing.T) { if got != nil { t.Errorf("Client.Federation() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Federation() error = %v, want %v", err, tt.response) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Federation() = %v, want %v", got, tt.response) @@ -790,16 +777,16 @@ func TestClient_SSHRoots(t *testing.T) { HostKeys: []api.SSHPublicKey{{PublicKey: key}}, UserKeys: []api.SSHPublicKey{{PublicKey: key}}, } - notFound := errs.NotFound(fmt.Errorf("Not Found")) tests := []struct { name string response interface{} responseCode int wantErr bool + err error }{ - {"ok", ok, 200, false}, - {"not found", notFound, 404, true}, + {"ok", ok, 200, false, nil}, + {"not found", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -829,9 +816,10 @@ func TestClient_SSHRoots(t *testing.T) { if got != nil { t.Errorf("Client.SSHKeys() = %v, want nil", got) } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.SSHKeys() error = %v, want %v", err, tt.response) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) @@ -948,7 +936,6 @@ func TestClient_SSHBastion(t *testing.T) { Hostname: "bastion.local", }, } - badRequest := errs.BadRequest(fmt.Errorf("Bad Request")) tests := []struct { name string @@ -956,11 +943,11 @@ func TestClient_SSHBastion(t *testing.T) { response interface{} responseCode int wantErr bool + err error }{ - {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false}, - {"bad response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true}, - {"empty request", &api.SSHBastionRequest{}, badRequest, 403, true}, - {"nil request", nil, badRequest, 403, true}, + {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, + {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, + {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -990,8 +977,11 @@ func TestClient_SSHBastion(t *testing.T) { if got != nil { t.Errorf("Client.SSHBastion() = %v, want nil", got) } - if tt.responseCode != 200 && !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.SSHBastion() error = %v, want %v", err, tt.response) + if tt.responseCode != 200 { + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, tt.err.Error(), err.Error()) } default: if !reflect.DeepEqual(got, tt.response) { diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index 3c04f982..139c6917 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -276,6 +276,7 @@ func TestIdentity_Renew(t *testing.T) { } oldIdentityDir := identityDir + identityDir = "testdata/identity" defer func() { identityDir = oldIdentityDir os.RemoveAll(tmpDir) diff --git a/db/db.go b/db/db.go index 8753cc1a..0934bffa 100644 --- a/db/db.go +++ b/db/db.go @@ -270,6 +270,105 @@ func (db *DB) Shutdown() error { return nil } +// MockAuthDB mocks the AuthDB interface. // +type MockAuthDB struct { + Err error + Ret1 interface{} + MIsRevoked func(string) (bool, error) + MIsSSHRevoked func(string) (bool, error) + MRevoke func(rci *RevokedCertificateInfo) error + MRevokeSSH func(rci *RevokedCertificateInfo) error + MStoreCertificate func(crt *x509.Certificate) error + MUseToken func(id, tok string) (bool, error) + MIsSSHHost func(principal string) (bool, error) + MStoreSSHCertificate func(crt *ssh.Certificate) error + MGetSSHHostPrincipals func() ([]string, error) + MShutdown func() error +} + +// IsRevoked mock. +func (m *MockAuthDB) IsRevoked(sn string) (bool, error) { + if m.MIsRevoked != nil { + return m.MIsRevoked(sn) + } + return m.Ret1.(bool), m.Err +} + +// IsSSHRevoked mock. +func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) { + if m.MIsSSHRevoked != nil { + return m.MIsSSHRevoked(sn) + } + return m.Ret1.(bool), m.Err +} + +// UseToken mock. +func (m *MockAuthDB) UseToken(id, tok string) (bool, error) { + if m.MUseToken != nil { + return m.MUseToken(id, tok) + } + if m.Ret1 == nil { + return false, m.Err + } + return m.Ret1.(bool), m.Err +} + +// Revoke mock. +func (m *MockAuthDB) Revoke(rci *RevokedCertificateInfo) error { + if m.MRevoke != nil { + return m.MRevoke(rci) + } + return m.Err +} + +// RevokeSSH mock. +func (m *MockAuthDB) RevokeSSH(rci *RevokedCertificateInfo) error { + if m.MRevokeSSH != nil { + return m.MRevokeSSH(rci) + } + return m.Err +} + +// StoreCertificate mock. +func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error { + if m.MStoreCertificate != nil { + return m.MStoreCertificate(crt) + } + return m.Err +} + +// IsSSHHost mock. +func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) { + if m.MIsSSHHost != nil { + return m.MIsSSHHost(principal) + } + return m.Ret1.(bool), m.Err +} + +// StoreSSHCertificate mock. +func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error { + if m.MStoreSSHCertificate != nil { + return m.MStoreSSHCertificate(crt) + } + return m.Err +} + +// GetSSHHostPrincipals mock. +func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) { + if m.MGetSSHHostPrincipals != nil { + return m.MGetSSHHostPrincipals() + } + return m.Ret1.([]string), m.Err +} + +// Shutdown mock. +func (m *MockAuthDB) Shutdown() error { + if m.MShutdown != nil { + return m.MShutdown() + } + return m.Err +} + // MockNoSQLDB // type MockNoSQLDB struct { Err error diff --git a/errs/error.go b/errs/error.go index 825cf549..adae017e 100644 --- a/errs/error.go +++ b/errs/error.go @@ -21,9 +21,9 @@ type StackTracer interface { // Option modifies the Error type. type Option func(e *Error) error -// WithMessage returns an Option that modifies the error by overwriting the +// withDefaultMessage returns an Option that modifies the error by overwriting the // message only if it is empty. -func WithMessage(format string, args ...interface{}) Option { +func withDefaultMessage(format string, args ...interface{}) Option { return func(e *Error) error { if len(e.Msg) > 0 { return e @@ -33,25 +33,52 @@ func WithMessage(format string, args ...interface{}) Option { } } +// WithMessage returns an Option that modifies the error by overwriting the +// message only if it is empty. +func WithMessage(format string, args ...interface{}) Option { + return func(e *Error) error { + e.Msg = fmt.Sprintf(format, args...) + return e + } +} + +// WithKeyVal returns an Option that adds the given key-value pair to the +// Error details. This is helpful for debugging errors. +func WithKeyVal(key string, val interface{}) Option { + return func(e *Error) error { + if e.Details == nil { + e.Details = make(map[string]interface{}) + } + e.Details[key] = val + return e + } +} + // Error represents the CA API errors. type Error struct { - Status int - Err error - Msg string + Status int + Err error + Msg string + Details map[string]interface{} } // New returns a new Error. If the given error implements the StatusCoder // interface we will ignore the given status. func New(status int, err error, opts ...Option) error { - var e *Error - if sc, ok := err.(StatusCoder); ok { - e = &Error{Status: sc.StatusCode(), Err: err} - } else { - cause := errors.Cause(err) - if sc, ok := cause.(StatusCoder); ok { + var ( + e *Error + ok bool + ) + if e, ok = err.(*Error); !ok { + if sc, ok := err.(StatusCoder); ok { e = &Error{Status: sc.StatusCode(), Err: err} } else { - e = &Error{Status: status, Err: err} + cause := errors.Cause(err) + if sc, ok := cause.(StatusCoder); ok { + e = &Error{Status: sc.StatusCode(), Err: err} + } else { + e = &Error{Status: status, Err: err} + } } } for _, o := range opts { @@ -188,63 +215,62 @@ func StatusCodeError(code int, e error, opts ...Option) error { } } -var seeLogs = "Please see the certificate authority logs for more info." +var ( + seeLogs = "Please see the certificate authority logs for more info." + // BadRequestDefaultMsg 400 default msg + BadRequestDefaultMsg = "The request could not be completed due to being poorly formatted or missing critical data. " + seeLogs + // UnauthorizedDefaultMsg 401 default msg + UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs + // ForbiddenDefaultMsg 403 default msg + ForbiddenDefaultMsg = "The request was forbidden by the certificate authority. " + seeLogs + // NotFoundDefaultMsg 404 default msg + NotFoundDefaultMsg = "The requested resource could not be found. " + seeLogs + // InternalServerErrorDefaultMsg 500 default msg + InternalServerErrorDefaultMsg = "The certificate authority encountered an Internal Server Error. " + seeLogs + // NotImplementedDefaultMsg 501 default msg + NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs +) // InternalServerError returns a 500 error with the given error. func InternalServerError(err error, opts ...Option) error { - if len(opts) == 0 { - opts = append(opts, WithMessage("The certificate authority encountered an Internal Server Error. "+seeLogs)) - } + opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg)) return New(http.StatusInternalServerError, err, opts...) } // NotImplemented returns a 501 error with the given error. func NotImplemented(err error, opts ...Option) error { - if len(opts) == 0 { - opts = append(opts, WithMessage("The requested method is not implemented by the certificate authority. "+seeLogs)) - } + opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg)) return New(http.StatusNotImplemented, err, opts...) } // BadRequest returns an 400 error with the given error. func BadRequest(err error, opts ...Option) error { - if len(opts) == 0 { - opts = append(opts, WithMessage("The request could not be completed due to being poorly formatted or "+ - "missing critical data. "+seeLogs)) - } + opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) return New(http.StatusBadRequest, err, opts...) } // Unauthorized returns an 401 error with the given error. func Unauthorized(err error, opts ...Option) error { - if len(opts) == 0 { - opts = append(opts, WithMessage("The request lacked necessary authorization to be completed. "+seeLogs)) - } + opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg)) return New(http.StatusUnauthorized, err, opts...) } // Forbidden returns an 403 error with the given error. func Forbidden(err error, opts ...Option) error { - if len(opts) == 0 { - opts = append(opts, WithMessage("The request was Forbidden by the certificate authority. "+seeLogs)) - } + opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg)) return New(http.StatusForbidden, err, opts...) } // NotFound returns an 404 error with the given error. func NotFound(err error, opts ...Option) error { - if len(opts) == 0 { - opts = append(opts, WithMessage("The requested resource could not be found. "+seeLogs)) - } + opts = append(opts, withDefaultMessage(NotFoundDefaultMsg)) return New(http.StatusNotFound, err, opts...) } // UnexpectedError will be used when the certificate authority makes an outgoing // request and receives an unhandled status code. func UnexpectedError(code int, err error, opts ...Option) error { - if len(opts) == 0 { - opts = append(opts, WithMessage("The certificate authority received an "+ - "unexpected HTTP status code - '%d'. "+seeLogs, code)) - } + opts = append(opts, withDefaultMessage("The certificate authority received an "+ + "unexpected HTTP status code - '%d'. "+seeLogs, code)) return New(code, err, opts...) } diff --git a/api/errors_test.go b/errs/errors_test.go similarity index 87% rename from api/errors_test.go rename to errs/errors_test.go index 1f63142a..58b95437 100644 --- a/api/errors_test.go +++ b/errs/errors_test.go @@ -1,11 +1,9 @@ -package api +package errs import ( "fmt" "reflect" "testing" - - "github.com/smallstep/certificates/errs" ) func TestError_MarshalJSON(t *testing.T) { @@ -24,7 +22,7 @@ func TestError_MarshalJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - e := &errs.Error{ + e := &Error{ Status: tt.fields.Status, Err: tt.fields.Err, } @@ -47,15 +45,15 @@ func TestError_UnmarshalJSON(t *testing.T) { tests := []struct { name string args args - expected *errs.Error + expected *Error wantErr bool }{ - {"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &errs.Error{Status: 400, Err: fmt.Errorf("bad request")}, false}, - {"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &errs.Error{}, true}, + {"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &Error{Status: 400, Err: fmt.Errorf("bad request")}, false}, + {"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &Error{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - e := new(errs.Error) + e := new(Error) if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) }