From b26587705025ed7823a2f002a2cc0d9fa669860a Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 23 Jan 2020 22:04:34 -0800 Subject: [PATCH] Simplify statuscoder error generators. --- .golangci.yml | 1 + api/api.go | 12 +- api/api_test.go | 8 +- api/renew.go | 5 +- api/revoke.go | 17 +- api/revoke_test.go | 2 +- api/sign.go | 15 +- api/ssh.go | 46 ++-- api/sshRekey.go | 14 +- api/sshRenew.go | 12 +- api/sshRevoke.go | 15 +- api/utils.go | 3 +- authority/authorize.go | 25 ++- authority/authorize_test.go | 6 +- authority/provisioner/acme.go | 2 +- authority/provisioner/aws.go | 28 +-- authority/provisioner/aws_test.go | 2 +- authority/provisioner/azure.go | 22 +- authority/provisioner/azure_test.go | 2 +- authority/provisioner/gcp.go | 30 +-- authority/provisioner/gcp_test.go | 2 +- authority/provisioner/jwk.go | 18 +- authority/provisioner/jwk_test.go | 4 +- authority/provisioner/k8sSA.go | 14 +- authority/provisioner/k8sSA_test.go | 4 +- authority/provisioner/noop_test.go | 4 +- authority/provisioner/oidc.go | 24 +-- authority/provisioner/provisioner.go | 14 +- authority/provisioner/sign_ssh_options.go | 60 +++--- .../provisioner/sign_ssh_options_test.go | 10 +- authority/provisioner/ssh_test.go | 12 +- authority/provisioner/sshpop.go | 28 +-- authority/provisioner/sshpop_test.go | 4 +- authority/provisioner/x5c.go | 20 +- authority/provisioner/x5c_test.go | 12 +- authority/provisioners.go | 12 +- authority/provisioners_test.go | 28 ++- authority/root.go | 12 +- authority/root_test.go | 24 +-- authority/ssh.go | 72 +++---- authority/ssh_test.go | 14 +- authority/tls.go | 16 +- ca/client.go | 10 +- ca/client_test.go | 40 ++-- errs/error.go | 199 ++++++++++++------ 45 files changed, 483 insertions(+), 441 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index f0c2eed0..0aed855d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -63,6 +63,7 @@ issues: - declaration of "err" shadows declaration at line - should have a package comment, unless it's in another file for this package - error strings should not be capitalized or end with punctuation or a newline + - Wrapf call needs 1 arg but has 2 args # golangci.com configuration # https://github.com/golangci/golangci/wiki/Configuration service: diff --git a/api/api.go b/api/api.go index c4b307b3..37222be8 100644 --- a/api/api.go +++ b/api/api.go @@ -295,7 +295,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { // Load root certificate with the cert, err := h.Authority.Root(sum) if err != nil { - WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI))) + WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return } @@ -314,13 +314,13 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := parseCursor(r) if err != nil { - WriteError(w, errs.BadRequest(err)) + WriteError(w, errs.BadRequestErr(err)) return } p, next, err := h.Authority.GetProvisioners(cursor, limit) if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) return } JSON(w, &ProvisionersResponse{ @@ -334,7 +334,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") key, err := h.Authority.GetEncryptedKey(kid) if err != nil { - WriteError(w, errs.NotFound(err)) + WriteError(w, errs.NotFoundErr(err)) return } JSON(w, &ProvisionerKeyResponse{key}) @@ -344,7 +344,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } @@ -362,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { federated, err := h.Authority.GetFederation() if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } diff --git a/api/api_test.go b/api/api_test.go index 9f40a8e0..cbaf806f 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -915,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) { {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, errs.Forbidden(fmt.Errorf("an error")), http.StatusForbidden}, + {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, } expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) @@ -1010,10 +1010,10 @@ func Test_caHandler_Provisioners(t *testing.T) { t.Fatal(err) } - expectedError400 := errs.BadRequest(errors.New("force")) + expectedError400 := errs.BadRequest("force") expectedError400Bytes, err := json.Marshal(expectedError400) assert.FatalError(t, err) - expectedError500 := errs.InternalServerError(errors.New("force")) + expectedError500 := errs.InternalServer("force") expectedError500Bytes, err := json.Marshal(expectedError500) assert.FatalError(t, err) for _, tt := range tests { @@ -1082,7 +1082,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { } expected := []byte(`{"key":"` + privKey + `"}`) - expectedError404 := errs.NotFound(errors.New("force")) + expectedError404 := errs.NotFound("force") expectedError404Bytes, err := json.Marshal(expectedError404) assert.FatalError(t, err) diff --git a/api/renew.go b/api/renew.go index bc42ec24..bf32518b 100644 --- a/api/renew.go +++ b/api/renew.go @@ -3,7 +3,6 @@ package api import ( "net/http" - "github.com/pkg/errors" "github.com/smallstep/certificates/errs" ) @@ -11,7 +10,7 @@ import ( // new one. func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest(errors.New("missing peer certificate"))) + WriteError(w, errs.BadRequest("missing peer certificate")) return } @@ -22,7 +21,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { } certChainPEM := certChainToPEM(certChain) var caPEM Certificate - if len(certChainPEM) > 0 { + if len(certChainPEM) > 1 { caPEM = certChainPEM[1] } diff --git a/api/revoke.go b/api/revoke.go index df974cbe..547ed366 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -4,7 +4,6 @@ import ( "context" "net/http" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -30,13 +29,13 @@ type RevokeRequest struct { // or an error if something is wrong. func (r *RevokeRequest) Validate() (err error) { if r.Serial == "" { - return errs.BadRequest(errors.New("missing serial")) + return errs.BadRequest("missing serial") } if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { - return errs.BadRequest(errors.New("reasonCode out of bounds")) + return errs.BadRequest("reasonCode out of bounds") } if !r.Passive { - return errs.NotImplemented(errors.New("non-passive revocation not implemented")) + return errs.NotImplemented("non-passive revocation not implemented") } return @@ -50,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } @@ -72,7 +71,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { if len(body.OTT) > 0 { logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.Unauthorized(err)) + WriteError(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT @@ -81,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate"))) + WriteError(w, errs.BadRequest("missing ott or peer certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body"))) + WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. @@ -97,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } diff --git a/api/revoke_test.go b/api/revoke_test.go index e6aef11a..f44acebf 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -190,7 +190,7 @@ func Test_caHandler_Revoke(t *testing.T) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { - return errs.InternalServerError(errors.New("force")) + return errs.InternalServer("force") }, }, } diff --git a/api/sign.go b/api/sign.go index e76f6256..f30b0b4b 100644 --- a/api/sign.go +++ b/api/sign.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "net/http" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/tlsutil" @@ -22,13 +21,13 @@ type SignRequest struct { // or an error if something is wrong. func (s *SignRequest) Validate() error { if s.CsrPEM.CertificateRequest == nil { - return errs.BadRequest(errors.New("missing csr")) + return errs.BadRequest("missing csr") } if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { - return errs.BadRequest(errors.Wrap(err, "invalid csr")) + return errs.Wrap(http.StatusBadRequest, err, "invalid csr") } if s.OTT == "" { - return errs.BadRequest(errors.New("missing ott")) + return errs.BadRequest("missing ott") } return nil @@ -49,7 +48,7 @@ type SignResponse struct { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } @@ -66,18 +65,18 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { signOpts, err := h.Authority.AuthorizeSign(body.OTT) if err != nil { - WriteError(w, errs.Unauthorized(err)) + WriteError(w, errs.UnauthorizedErr(err)) return } certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } certChainPEM := certChainToPEM(certChain) var caPEM Certificate - if len(certChainPEM) > 0 { + if len(certChainPEM) > 1 { caPEM = certChainPEM[1] } logCertificate(w, certChain[0]) diff --git a/api/ssh.go b/api/ssh.go index 2206973b..f0b090d1 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -249,19 +249,19 @@ type SSHBastionResponse struct { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequest(err)) + WriteError(w, errs.BadRequestErr(err)) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) return } @@ -269,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey")) return } } @@ -285,13 +285,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.Unauthorized(err)) + WriteError(w, errs.UnauthorizedErr(err)) return } cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...) if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } @@ -299,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert) if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } addUserCertificate = &SSHCertificate{addUserCert} @@ -320,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.Unauthorized(err)) + WriteError(w, errs.UnauthorizedErr(err)) return } certChain, err := h.Authority.Sign(cr, opts, signOpts...) if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } identityCertificate = certChainToPEM(certChain) @@ -343,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHRoots() if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound(errors.New("no keys found"))) + WriteError(w, errs.NotFound("no keys found")) return } @@ -368,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHFederation() if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound(errors.New("no keys found"))) + WriteError(w, errs.NotFound("no keys found")) return } @@ -393,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequest(err)) + WriteError(w, errs.BadRequestErr(err)) return } ts, err := h.Authority.GetSSHConfig(body.Type, body.Data) if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) return } @@ -414,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { case provisioner.SSHHostCert: config.HostTemplates = ts default: - WriteError(w, errs.InternalServerError(errors.New("it should hot get here"))) + WriteError(w, errs.InternalServer("it should hot get here")) return } @@ -429,13 +429,13 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequest(err)) + WriteError(w, errs.BadRequestErr(err)) return } exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) return } JSON(w, &SSHCheckPrincipalResponse{ @@ -452,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { hosts, err := h.Authority.GetSSHHosts(cert) if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) return } JSON(w, &SSHGetHostsResponse{ @@ -464,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequest(err)) + WriteError(w, errs.BadRequestErr(err)) return } bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname) if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) return } diff --git a/api/sshRekey.go b/api/sshRekey.go index efeee141..a5cc1f06 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -40,42 +40,42 @@ type SSHRekeyResponse struct { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequest(err)) + WriteError(w, errs.BadRequestErr(err)) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) return } ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.Unauthorized(err)) + WriteError(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) } newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...) if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } identity, err := h.renewIdentityCertificate(r) if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index fd4ff1ee..11a9d8e8 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -36,36 +36,36 @@ type SSHRenewResponse struct { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequest(err)) + WriteError(w, errs.BadRequestErr(err)) return } ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod) _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.Unauthorized(err)) + WriteError(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerError(err)) + WriteError(w, errs.InternalServerErr(err)) } newCert, err := h.Authority.RenewSSH(oldCert) if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } identity, err := h.renewIdentityCertificate(r) if err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index cd4a3a3e..b8d1dadd 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -4,7 +4,6 @@ import ( "context" "net/http" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -30,16 +29,16 @@ type SSHRevokeRequest struct { // or an error if something is wrong. func (r *SSHRevokeRequest) Validate() (err error) { if r.Serial == "" { - return errs.BadRequest(errors.New("missing serial")) + return errs.BadRequest("missing serial") } if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { - return errs.BadRequest(errors.New("reasonCode out of bounds")) + return errs.BadRequest("reasonCode out of bounds") } if !r.Passive { - return errs.NotImplemented(errors.New("non-passive revocation not implemented")) + return errs.NotImplemented("non-passive revocation not implemented") } if len(r.OTT) == 0 { - return errs.BadRequest(errors.New("missing ott")) + return errs.BadRequest("missing ott") } return } @@ -50,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body"))) + WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) return } @@ -71,13 +70,13 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.Unauthorized(err)) + WriteError(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.Forbidden(err)) + WriteError(w, errs.ForbiddenErr(err)) return } diff --git a/api/utils.go b/api/utils.go index 56beb2b5..0d87a065 100644 --- a/api/utils.go +++ b/api/utils.go @@ -6,7 +6,6 @@ import ( "log" "net/http" - "github.com/pkg/errors" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" ) @@ -69,7 +68,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) { // pointed by v. func ReadJSON(r io.Reader, v interface{}) error { if err := json.NewDecoder(r).Decode(v); err != nil { - return errs.BadRequest(errors.Wrap(err, "error decoding json")) + return errs.Wrap(http.StatusBadRequest, err, "error decoding json") } return nil } diff --git a/authority/authorize.go b/authority/authorize.go index cdca026d..bda59520 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -6,7 +6,6 @@ import ( "net/http" "strings" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/jose" @@ -58,15 +57,15 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // This check is meant as a stopgap solution to the current lack of a persistence layer. if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { - return nil, errs.Unauthorized(errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority")) + return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority") } } // This method will also validate the audiences for JWK provisioners. p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) if !ok { - return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: provisioner "+ - "not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))) + return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+ + "not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")) } // Store the token to protect against reuse unless it's skipped. @@ -78,7 +77,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision "authority.authorizeToken: failed when attempting to store token") } if !ok { - return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: token already used")) + return nil, errs.Unauthorized("authority.authorizeToken: token already used") } } } @@ -89,7 +88,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // Authorize grabs the method from the context and authorizes the request by // validating the one-time-token. func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) { - var opts = []errs.Option{errs.WithKeyVal("token", token)} + var opts = []interface{}{errs.WithKeyVal("token", token)} switch m := provisioner.MethodFromContext(ctx); m { case provisioner.SignMethod: @@ -99,13 +98,13 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner. return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...) case provisioner.SSHSignMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) + return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...) } _, err := a.authorizeSSHSign(ctx, token) return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) case provisioner.SSHRenewMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) + return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...) } _, err := a.authorizeSSHRenew(ctx, token) return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) @@ -113,12 +112,12 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner. return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...) case provisioner.SSHRekeyMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...) + return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...) } _, signOpts, err := a.authorizeSSHRekey(ctx, token) return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) default: - return nil, errs.InternalServerError(errors.Errorf("authority.Authorize; method %d is not supported", m), opts...) + return nil, errs.InternalServer("authority.Authorize; method %d is not supported", append([]interface{}{m}, opts...)...) } } @@ -165,7 +164,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { // // TODO(mariano): should we authorize by default? func (a *Authority) authorizeRenew(cert *x509.Certificate) error { - var opts = []errs.Option{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())} + var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())} // Check the passive revocation table. isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String()) @@ -173,12 +172,12 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } if isRevoked { - return errs.Unauthorized(errors.New("authority.authorizeRenew: certificate has been revoked"), opts...) + return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) } p, ok := a.provisioners.LoadByCertificate(cert) if !ok { - return errs.Unauthorized(errors.New("authority.authorizeRenew: provisioner not found"), opts...) + return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) } if err := p.AuthorizeRenew(context.Background(), cert); err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 6f7bf940..e4863764 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -180,7 +180,7 @@ func TestAuthority_authorizeToken(t *testing.T) { } raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) - _, err = _a.authorizeToken(context.TODO(), raw) + _, err = _a.authorizeToken(context.Background(), raw) assert.FatalError(t, err) return &authorizeTest{ auth: _a, @@ -268,7 +268,7 @@ func TestAuthority_authorizeToken(t *testing.T) { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - p, err := tc.auth.authorizeToken(context.TODO(), tc.token) + p, err := tc.auth.authorizeToken(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) @@ -355,7 +355,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) { t.Run(name, func(t *testing.T) { tc := genTestCase(t) - if err := tc.auth.authorizeRevoke(context.TODO(), tc.token); err != nil { + if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 7adeb311..e414410b 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -80,7 +80,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // certificate was configured to allow renewals. func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized(errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())) + return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()) } return nil } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 39769118..16820909 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -306,7 +306,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // certificate was configured to allow renewals. func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized(errors.Errorf("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())) + return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID()) } return nil } @@ -353,7 +353,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token") } if len(jwt.Headers) == 0 { - return nil, errs.InternalServerError(errors.New("aws.authorizeToken; error parsing token, header is missing")) + return nil, errs.InternalServer("aws.authorizeToken; error parsing token, header is missing") } var unsafeClaims awsPayload @@ -378,13 +378,13 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { switch { case doc.AccountID == "": - return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document accountId cannot be empty")) + return nil, errs.Unauthorized("aws.authorizeToken; aws identity document accountId cannot be empty") case doc.InstanceID == "": - return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty")) + return nil, errs.Unauthorized("aws.authorizeToken; aws identity document instanceId cannot be empty") case doc.PrivateIP == "": - return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty")) + return nil, errs.Unauthorized("aws.authorizeToken; aws identity document privateIp cannot be empty") case doc.Region == "": - return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document region cannot be empty")) + return nil, errs.Unauthorized("aws.authorizeToken; aws identity document region cannot be empty") } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -399,7 +399,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { // validate audiences with the defaults if !matchesAudience(payload.Audience, p.audiences.Sign) { - return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)")) + return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)") } // Validate subject, it has to be known if disableCustomSANs is enabled @@ -407,7 +407,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { if payload.Subject != doc.InstanceID && payload.Subject != doc.PrivateIP && payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) { - return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)")) + return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)") } } @@ -421,14 +421,14 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { } } if !found { - return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid")) + return nil, errs.Unauthorized("aws.authorizeToken; invalid aws identity document - accountId is not valid") } } // validate instance age if d := p.InstanceAge.Value(); d > 0 { if now.Sub(doc.PendingTime) > d { - return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document pendingTime is too old")) + return nil, errs.Unauthorized("aws.authorizeToken; aws identity document pendingTime is too old") } } @@ -439,7 +439,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized(errors.Errorf("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())) + return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID()) } claims, err := p.authorizeToken(token) if err != nil { @@ -462,7 +462,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, }, } // Validate user options - signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) + signOptions = append(signOptions, sshCertOptionsValidator(defaults)) // Set defaults if not given as user options signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) @@ -474,8 +474,8 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertificateValidityValidator{p.claimer}, + &sshCertValidityValidator{p.claimer}, // Require all the fields in the SSH certificate - &sshCertificateDefaultValidator{}, + &sshCertDefaultValidator{}, ), nil } diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 8c59bebe..5e9ea92c 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -704,7 +704,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.aws.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { + if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { sc, ok := err.(errs.StatusCoder) diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 86eb516f..88755c2a 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -210,14 +210,14 @@ func (p *Azure) Init(config Config) (err error) { return nil } -// authorizeToken returs the claims, name, group, error. +// authorizeToken returns the claims, name, group, error. func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token") } if len(jwt.Headers) == 0 { - return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token missing header")) + return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header") } var found bool @@ -230,7 +230,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err } } if !found { - return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; cannot validate azure token")) + return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token") } if err := claims.ValidateWithLeeway(jose.Expected{ @@ -243,12 +243,12 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err // Validate TenantID if claims.TenantID != p.TenantID { - return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")) + return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)") } re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) if len(re) != 4 { - return nil, "", "", errs.Unauthorized(errors.Errorf("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)) + return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) } group, name := re[2], re[3] return &claims, name, group, nil @@ -272,7 +272,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } } if !found { - return nil, errs.Unauthorized(errors.New("azure.AuthorizeSign; azure token validation failed - invalid resource group")) + return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid resource group") } } @@ -302,7 +302,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // certificate was configured to allow renewals. func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized(errors.Errorf("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())) + return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID()) } return nil } @@ -310,7 +310,7 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized(errors.Errorf("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())) + return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID()) } _, name, _, err := p.authorizeToken(token) @@ -328,7 +328,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio Principals: []string{name}, } // Validate user options - signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) + signOptions = append(signOptions, sshCertOptionsValidator(defaults)) // Set defaults if not given as user options signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) @@ -340,9 +340,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertificateValidityValidator{p.claimer}, + &sshCertValidityValidator{p.claimer}, // Require all the fields in the SSH certificate - &sshCertificateDefaultValidator{}, + &sshCertDefaultValidator{}, ), nil } diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 13e6ac8e..f49624cc 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -488,7 +488,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { + if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { sc, ok := err.(errs.StatusCoder) diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index 69a3006a..d55b702f 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -243,7 +243,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // AuthorizeRenew returns an error if the renewal is disabled. func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized(errors.Errorf("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())) + return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID()) } return nil } @@ -264,7 +264,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token") } if len(jwt.Headers) == 0 { - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; error parsing gcp token - header is missing")) + return nil, errs.Unauthorized("gcp.authorizeToken; error parsing gcp token - header is missing") } var found bool @@ -278,7 +278,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } } if !found { - return nil, errs.Unauthorized(errors.Errorf("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)) + return nil, errs.Unauthorized("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid) } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -293,7 +293,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // validate audiences with the defaults if !matchesAudience(claims.Audience, p.audiences.Sign) { - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")) + return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") } // validate subject (service account) @@ -306,7 +306,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } } if !found { - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim")) + return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid subject claim") } } @@ -320,26 +320,26 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } } if !found { - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid project id")) + return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid project id") } } // validate instance age if d := p.InstanceAge.Value(); d > 0 { if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d { - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")) + return nil, errs.Unauthorized("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old") } } switch { case claims.Google.ComputeEngine.InstanceID == "": - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")) + return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty") case claims.Google.ComputeEngine.InstanceName == "": - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")) + return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty") case claims.Google.ComputeEngine.ProjectID == "": - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")) + return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty") case claims.Google.ComputeEngine.Zone == "": - return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")) + return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty") } return &claims, nil @@ -348,7 +348,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized(errors.Errorf("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())) + return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID()) } claims, err := p.authorizeToken(token) if err != nil { @@ -371,7 +371,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, }, } // Validate user options - signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) + signOptions = append(signOptions, sshCertOptionsValidator(defaults)) // Set defaults if not given as user options signOptions = append(signOptions, sshCertDefaultsModifier(defaults)) @@ -383,8 +383,8 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertificateValidityValidator{p.claimer}, + &sshCertValidityValidator{p.claimer}, // Require all the fields in the SSH certificate - &sshCertificateDefaultValidator{}, + &sshCertDefaultValidator{}, ), nil } diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index bdda8fd9..0fbb4b41 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -680,7 +680,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { + if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { sc, ok := err.(errs.StatusCoder) diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 1c613de6..57297f78 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -120,12 +120,12 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err // validate audiences with the defaults if !matchesAudience(claims.Audience, audiences) { - return nil, errs.Unauthorized(errors.Errorf("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s", - audiences, claims.Audience)) + return nil, errs.Unauthorized("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s", + audiences, claims.Audience) } if claims.Subject == "" { - return nil, errs.Unauthorized(errors.New("jwk.authorizeToken; jwk token subject cannot be empty")) + return nil, errs.Unauthorized("jwk.authorizeToken; jwk token subject cannot be empty") } return &claims, nil @@ -173,7 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized(errors.Errorf("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())) + return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID()) } return nil } @@ -181,20 +181,20 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized(errors.Errorf("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())) + return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID()) } claims, err := p.authorizeToken(token, p.audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") } if claims.Step == nil || claims.Step.SSH == nil { - return nil, errs.Unauthorized(errors.New("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")) + return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token") } opts := claims.Step.SSH signOptions := []SignOption{ // validates user's SSHOptions with the ones in the token - sshCertificateOptionsValidator(*opts), + sshCertOptionsValidator(*opts), } t := now() @@ -231,9 +231,9 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertificateValidityValidator{p.claimer}, + &sshCertValidityValidator{p.claimer}, // Require and validate all the default fields in the SSH certificate. - &sshCertificateDefaultValidator{}, + &sshCertDefaultValidator{}, ), nil } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index a0c48ee9..ed97d8f1 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -222,7 +222,7 @@ func TestJWK_AuthorizeRevoke(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token); err != nil { + if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil { if assert.NotNil(t, tt.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") @@ -337,7 +337,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr { + if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { sc, ok := err.(errs.StatusCoder) diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index 0826028e..b63ce979 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -149,7 +149,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, claims k8sSAPayload ) if p.pubKeys == nil { - return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")) + return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented") /* NOTE: We plan to support the TokenReview API in a future release. Below is some code that should be useful when we prioritize this integration. @@ -177,7 +177,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, } } if !valid { - return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")) + return nil, errs.Unauthorized("k8ssa.authorizeToken; error validating k8sSA token and extracting claims") } // According to "rfc7519 JSON Web Token" acceptable skew should be no @@ -189,7 +189,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, } if claims.Subject == "" { - return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA token subject cannot be empty")) + return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA token subject cannot be empty") } return &claims, nil @@ -221,7 +221,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // AuthorizeRenew returns an error if the renewal is disabled. func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())) + return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()) } return nil } @@ -229,7 +229,7 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro // AuthorizeSSHSign validates an request for an SSH certificate. func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())) + return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()) } if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") @@ -246,9 +246,9 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertificateValidityValidator{p.claimer}, + &sshCertValidityValidator{p.claimer}, // Require and validate all the default fields in the SSH certificate. - &sshCertificateDefaultValidator{}, + &sshCertDefaultValidator{}, ), nil } diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 09a856c5..f1d12b4a 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -363,10 +363,10 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case sshCertDefaultsModifier: assert.Equals(t, v.CertType, SSHUserCert) case *sshDefaultExtensionModifier: - case *sshCertificateValidityValidator: + case *sshCertValidityValidator: assert.Equals(t, v.Claimer, tc.p.claimer) case *sshDefaultPublicKeyValidator: - case *sshCertificateDefaultValidator: + case *sshCertDefaultValidator: case *sshDefaultDuration: assert.Equals(t, v.Claimer, tc.p.claimer) default: diff --git a/authority/provisioner/noop_test.go b/authority/provisioner/noop_test.go index c79e7460..19e4d235 100644 --- a/authority/provisioner/noop_test.go +++ b/authority/provisioner/noop_test.go @@ -14,8 +14,8 @@ func Test_noop(t *testing.T) { assert.Equals(t, "noop", p.GetName()) assert.Equals(t, noopType, p.GetType()) assert.Equals(t, nil, p.Init(Config{})) - assert.Equals(t, nil, p.AuthorizeRenew(context.TODO(), &x509.Certificate{})) - assert.Equals(t, nil, p.AuthorizeRevoke(context.TODO(), "foo")) + assert.Equals(t, nil, p.AuthorizeRenew(context.Background(), &x509.Certificate{})) + assert.Equals(t, nil, p.AuthorizeRevoke(context.Background(), "foo")) kid, key, ok := p.GetEncryptedKey() assert.Equals(t, "", kid) diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 87710ebb..0b5448af 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -195,12 +195,12 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error { // Validate azp if present if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID { - return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: invalid azp")) + return errs.Unauthorized("validatePayload: failed to validate oidc token payload: invalid azp") } // Enforce an email claim if p.Email == "" { - return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email not found")) + return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email not found") } // Validate domains (case-insensitive) @@ -214,7 +214,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error { } } if !found { - return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email is not allowed")) + return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email is not allowed") } } @@ -230,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error { } } if !found { - return errs.Unauthorized(errors.New("validatePayload: oidc token payload validation failed: invalid group")) + return errs.Unauthorized("validatePayload: oidc token payload validation failed: invalid group") } } @@ -263,7 +263,7 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) { } } if !found { - return nil, errs.Unauthorized(errors.New("oidc.AuthorizeToken; cannot validate oidc token")) + return nil, errs.Unauthorized("oidc.AuthorizeToken; cannot validate oidc token") } if err := o.ValidatePayload(claims); err != nil { @@ -286,7 +286,7 @@ func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error { if o.IsAdmin(claims.Email) { return nil } - return errs.Unauthorized(errors.New("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")) + return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token") } // AuthorizeSign validates the given token. @@ -318,7 +318,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // certificate was configured to allow renewals. func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if o.claimer.IsDisableRenewal() { - return errs.Unauthorized(errors.Errorf("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())) + return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID()) } return nil } @@ -326,7 +326,7 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !o.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized(errors.Errorf("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())) + return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID()) } claims, err := o.authorizeToken(token) if err != nil { @@ -352,7 +352,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Non-admin users can only use principals returned by the identityFunc, and // can only sign user certificates. if !o.IsAdmin(claims.Email) { - signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) + signOptions = append(signOptions, sshCertOptionsValidator(defaults)) } // Default to a user certificate with usernames as principals if those options @@ -367,9 +367,9 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertificateValidityValidator{o.claimer}, + &sshCertValidityValidator{o.claimer}, // Require all the fields in the SSH certificate - &sshCertificateDefaultValidator{}, + &sshCertDefaultValidator{}, ), nil } @@ -382,7 +382,7 @@ func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error { // Only admins can revoke certificates. if !o.IsAdmin(claims.Email) { - return errs.Unauthorized(errors.New("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")) + return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token") } return nil } diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 40e1e309..fd342b01 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -284,43 +284,43 @@ type base struct{} // AuthorizeSign returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for signing x509 Certificates. func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSign not implemented")) + return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented") } // AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for revoking x509 Certificates. func (b *base) AuthorizeRevoke(ctx context.Context, token string) error { - return errs.Unauthorized(errors.New("provisioner.AuthorizeRevoke not implemented")) + return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented") } // AuthorizeRenew returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for renewing x509 Certificates. func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - return errs.Unauthorized(errors.New("provisioner.AuthorizeRenew not implemented")) + return errs.Unauthorized("provisioner.AuthorizeRenew not implemented") } // AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for signing SSH Certificates. func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHSign not implemented")) + return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented") } // AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for revoking SSH Certificates. func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error { - return errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRevoke not implemented")) + return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented") } // AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for renewing SSH Certificates. func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { - return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRenew not implemented")) + return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented") } // AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite // this method if they will support authorizing tokens for rekeying SSH Certificates. func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { - return nil, nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRekey not implemented")) + return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented") } // Identity is the type representing an externally supplied identity that is used diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index ec67baf1..b0ab78ea 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -19,29 +19,29 @@ const ( SSHHostCert = "host" ) -// SSHCertificateModifier is the interface used to change properties in an SSH +// SSHCertModifier is the interface used to change properties in an SSH // certificate. -type SSHCertificateModifier interface { +type SSHCertModifier interface { SignOption Modify(cert *ssh.Certificate) error } -// SSHCertificateOptionModifier is the interface used to add custom options used +// SSHCertOptionModifier is the interface used to add custom options used // to modify the SSH certificate. -type SSHCertificateOptionModifier interface { +type SSHCertOptionModifier interface { SignOption - Option(o SSHOptions) SSHCertificateModifier + Option(o SSHOptions) SSHCertModifier } -// SSHCertificateValidator is the interface used to validate an SSH certificate. -type SSHCertificateValidator interface { +// SSHCertValidator is the interface used to validate an SSH certificate. +type SSHCertValidator interface { SignOption Valid(cert *ssh.Certificate) error } -// SSHCertificateOptionsValidator is the interface used to validate the custom +// SSHCertOptionsValidator is the interface used to validate the custom // options used to modify the SSH certificate. -type SSHCertificateOptionsValidator interface { +type SSHCertOptionsValidator interface { SignOption Valid(got SSHOptions) error } @@ -69,7 +69,7 @@ func (o SSHOptions) Type() uint32 { return sshCertTypeUInt32(o.CertType) } -// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate. +// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate. func (o SSHOptions) Modify(cert *ssh.Certificate) error { switch o.CertType { case "": // ignore @@ -116,7 +116,7 @@ func (o SSHOptions) match(got SSHOptions) error { return nil } -// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the +// sshCertPrincipalsModifier is an SSHCertModifier that sets the // principals to the SSH certificate. type sshCertPrincipalsModifier []string @@ -126,7 +126,7 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshCertKeyIDModifier is an SSHCertificateModifier that sets the given +// sshCertKeyIDModifier is an SSHCertModifier that sets the given // Key ID in the SSH certificate. type sshCertKeyIDModifier string @@ -135,7 +135,7 @@ func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshCertTypeModifier is an SSHCertificateModifier that sets the +// sshCertTypeModifier is an SSHCertModifier that sets the // certificate type. type sshCertTypeModifier string @@ -145,7 +145,7 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshCertValidAfterModifier is an SSHCertificateModifier that sets the +// sshCertValidAfterModifier is an SSHCertModifier that sets the // ValidAfter in the SSH certificate. type sshCertValidAfterModifier uint64 @@ -154,7 +154,7 @@ func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshCertValidBeforeModifier is an SSHCertificateModifier that sets the +// sshCertValidBeforeModifier is an SSHCertModifier that sets the // ValidBefore in the SSH certificate. type sshCertValidBeforeModifier uint64 @@ -163,11 +163,11 @@ func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshCertDefaultsModifier implements a SSHCertificateModifier that +// sshCertDefaultsModifier implements a SSHCertModifier that // modifies the certificate with the given options if they are not set. type sshCertDefaultsModifier SSHOptions -// Modify implements the SSHCertificateModifier interface. +// Modify implements the SSHCertModifier interface. func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error { if cert.CertType == 0 { cert.CertType = sshCertTypeUInt32(m.CertType) @@ -184,7 +184,7 @@ func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets +// sshDefaultExtensionModifier implements an SSHCertModifier that sets // the default extensions in an SSH certificate. type sshDefaultExtensionModifier struct{} @@ -208,14 +208,14 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error { } } -// sshDefaultDuration is an SSHCertificateModifier that sets the certificate +// sshDefaultDuration is an SSHCertModifier that sets the certificate // ValidAfter and ValidBefore if they have not been set. It will fail if a // CertType has not been set or is not valid. type sshDefaultDuration struct { *Claimer } -func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertificateModifier { +func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertModifier { return sshModifierFunc(func(cert *ssh.Certificate) error { d, err := m.DefaultSSHCertDuration(cert.CertType) if err != nil { @@ -248,7 +248,7 @@ type sshLimitDuration struct { NotAfter time.Time } -func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier { +func (m *sshLimitDuration) Option(o SSHOptions) SSHCertModifier { if m.NotAfter.IsZero() { defaultDuration := &sshDefaultDuration{m.Claimer} return defaultDuration.Option(o) @@ -295,22 +295,22 @@ func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier { }) } -// sshCertificateOptionsValidator validates the user SSHOptions with the ones +// sshCertOptionsValidator validates the user SSHOptions with the ones // usually present in the token. -type sshCertificateOptionsValidator SSHOptions +type sshCertOptionsValidator SSHOptions -// Valid implements SSHCertificateOptionsValidator and returns nil if both +// Valid implements SSHCertOptionsValidator and returns nil if both // SSHOptions match. -func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error { +func (v sshCertOptionsValidator) Valid(got SSHOptions) error { want := SSHOptions(v) return want.match(got) } -type sshCertificateValidityValidator struct { +type sshCertValidityValidator struct { *Claimer } -func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error { +func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate) error { switch { case cert.ValidAfter == 0: return errors.New("ssh certificate validAfter cannot be 0") @@ -355,12 +355,12 @@ func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error { } } -// sshCertificateDefaultValidator implements a simple validator for all the +// sshCertDefaultValidator implements a simple validator for all the // fields in the SSH certificate. -type sshCertificateDefaultValidator struct{} +type sshCertDefaultValidator struct{} // Valid returns an error if the given certificate does not contain the necessary fields. -func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error { +func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate) error { switch { case len(cert.Nonce) == 0: return errors.New("ssh certificate nonce cannot be empty") diff --git a/authority/provisioner/sign_ssh_options_test.go b/authority/provisioner/sign_ssh_options_test.go index 87716e37..c13e46da 100644 --- a/authority/provisioner/sign_ssh_options_test.go +++ b/authority/provisioner/sign_ssh_options_test.go @@ -489,12 +489,12 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) { } } -func Test_sshCertificateDefaultValidator_Valid(t *testing.T) { +func Test_sshCertDefaultValidator_Valid(t *testing.T) { pub, _, err := keys.GenerateDefaultKeyPair() assert.FatalError(t, err) sshPub, err := ssh.NewPublicKey(pub) assert.FatalError(t, err) - v := sshCertificateDefaultValidator{} + v := sshCertDefaultValidator{} tests := []struct { name string cert *ssh.Certificate @@ -670,10 +670,10 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) { } } -func Test_sshCertificateValidityValidator(t *testing.T) { +func Test_sshCertValidityValidator(t *testing.T) { p, err := generateX5C(nil) assert.FatalError(t, err) - v := sshCertificateValidityValidator{p.claimer} + v := sshCertValidityValidator{p.claimer} n := now() tests := []struct { name string @@ -992,7 +992,7 @@ func Test_sshLimitDuration_Option(t *testing.T) { name string fields fields args args - want SSHCertificateModifier + want SSHCertModifier }{ // TODO: Add test cases. } diff --git a/authority/provisioner/ssh_test.go b/authority/provisioner/ssh_test.go index 1b31f78b..84860a75 100644 --- a/authority/provisioner/ssh_test.go +++ b/authority/provisioner/ssh_test.go @@ -45,22 +45,22 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp return nil, err } - var mods []SSHCertificateModifier - var validators []SSHCertificateValidator + var mods []SSHCertModifier + var validators []SSHCertValidator for _, op := range signOpts { switch o := op.(type) { // modify the ssh.Certificate - case SSHCertificateModifier: + case SSHCertModifier: mods = append(mods, o) // modify the ssh.Certificate given the SSHOptions - case SSHCertificateOptionModifier: + case SSHCertOptionModifier: mods = append(mods, o.Option(opts)) // validate the ssh.Certificate - case SSHCertificateValidator: + case SSHCertValidator: validators = append(validators, o) // validate the given SSHOptions - case SSHCertificateOptionsValidator: + case SSHCertOptionsValidator: if err := o.Valid(opts); err != nil { return nil, err } diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 3c55aada..db1c5a89 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -112,20 +112,20 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.authorizeToken; error checking checking sshpop cert revocation") } else if isRevoked { - return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate is revoked")) + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked") } // Check validity period of the certificate. n := time.Now() if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { - return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future")) + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") } if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { - return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past")) + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") } sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) if !ok { - return nil, errs.InternalServerError(errors.New("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")) + return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey") } pubKey := sshCryptoPubKey.CryptoPublicKey() @@ -146,7 +146,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa } } if !found { - return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")) + return nil, errs.Unauthorized("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate") } // Using the ssh certificates key to validate the claims accomplishes two @@ -170,12 +170,12 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa // validate audiences with the defaults if !matchesAudience(claims.Audience, audiences) { - return nil, errs.Unauthorized(errors.Errorf("sshpop.authorizeToken; sshpop token has invalid audience "+ - "claim (aud): expected %s, but got %s", audiences, claims.Audience)) + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token has invalid audience "+ + "claim (aud): expected %s, but got %s", audiences, claims.Audience) } if claims.Subject == "" { - return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty")) + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token subject cannot be empty") } claims.sshCert = sshCert @@ -190,8 +190,8 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { - return errs.BadRequest(errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject " + - "must be equivalent to sshpop certificate serial number")) + return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " + + "must be equivalent to sshpop certificate serial number") } return nil } @@ -204,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { - return nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")) + return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate") } return claims.sshCert, nil @@ -219,15 +219,15 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } if claims.sshCert.CertType != ssh.HostCert { - return nil, nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")) + return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate") } return claims.sshCert, []SignOption{ // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertificateValidityValidator{p.claimer}, + &sshCertValidityValidator{p.claimer}, // Require and validate all the default fields in the SSH certificate. - &sshCertificateDefaultValidator{}, + &sshCertDefaultValidator{}, }, nil } diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 32f58879..5863b6f9 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -564,8 +564,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { for _, o := range opts { switch v := o.(type) { case *sshDefaultPublicKeyValidator: - case *sshCertificateDefaultValidator: - case *sshCertificateValidityValidator: + case *sshCertDefaultValidator: + case *sshCertValidityValidator: assert.Equals(t, v.Claimer, tc.p.claimer) default: assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 692cd963..f00a215d 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -136,7 +136,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err leaf := verifiedChains[0][0] if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { - return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")) + return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature") } // Using the leaf certificates key to validate the claims accomplishes two @@ -160,12 +160,12 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err // validate audiences with the defaults if !matchesAudience(claims.Audience, audiences) { - return nil, errs.Unauthorized(errors.Errorf("x5c.authorizeToken; x5c token has invalid audience "+ - "claim (aud); expected %s, but got %s", audiences, claims.Audience)) + return nil, errs.Unauthorized("x5c.authorizeToken; x5c token has invalid audience "+ + "claim (aud); expected %s, but got %s", audiences, claims.Audience) } if claims.Subject == "" { - return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; x5c token subject cannot be empty")) + return nil, errs.Unauthorized("x5c.authorizeToken; x5c token subject cannot be empty") } // Save the verified chains on the x5c payload object. @@ -213,7 +213,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // AuthorizeRenew returns an error if the renewal is disabled. func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized(errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())) + return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()) } return nil } @@ -221,7 +221,7 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized(errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())) + return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()) } claims, err := p.authorizeToken(token, p.audiences.SSHSign) @@ -230,13 +230,13 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, } if claims.Step == nil || claims.Step.SSH == nil { - return nil, errs.Unauthorized(errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")) + return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token") } opts := claims.Step.SSH signOptions := []SignOption{ // validates user's SSHOptions with the ones in the token - sshCertificateOptionsValidator(*opts), + sshCertOptionsValidator(*opts), } // Add modifiers from custom claims @@ -272,8 +272,8 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertificateValidityValidator{p.claimer}, + &sshCertValidityValidator{p.claimer}, // Require all the fields in the SSH certificate - &sshCertificateDefaultValidator{}, + &sshCertDefaultValidator{}, ), nil } diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 775f3202..3ebaeb6b 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -548,7 +548,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil { + if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") @@ -594,7 +594,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if err := tc.p.AuthorizeRenew(context.TODO(), nil); err != nil { + if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") @@ -754,7 +754,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil { + if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") @@ -768,7 +768,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { nw := now() for _, o := range opts { switch v := o.(type) { - case sshCertificateOptionsValidator: + case sshCertOptionsValidator: tc.claims.Step.SSH.ValidAfter.t = time.Time{} tc.claims.Step.SSH.ValidBefore.t = time.Time{} assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH) @@ -787,10 +787,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { case *sshLimitDuration: assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) - case *sshCertificateValidityValidator: + case *sshCertValidityValidator: assert.Equals(t, v.Claimer, tc.p.claimer) case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator, - *sshCertificateDefaultValidator: + *sshCertDefaultValidator: case sshCertKeyIDValidator: assert.Equals(t, string(v), "foo") default: diff --git a/authority/provisioners.go b/authority/provisioners.go index 2d43571b..99a85d46 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -2,18 +2,16 @@ package authority import ( "crypto/x509" - "net/http" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" ) // GetEncryptedKey returns the JWE key corresponding to the given kid argument. func (a *Authority) GetEncryptedKey(kid string) (string, error) { key, ok := a.provisioners.LoadEncryptedKey(kid) if !ok { - return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid), - http.StatusNotFound, apiCtx{}} + return "", errs.NotFound("encrypted key with kid %s was not found", kid) } return key, nil } @@ -30,8 +28,7 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) { p, ok := a.provisioners.LoadByCertificate(crt) if !ok { - return nil, &apiError{errors.Errorf("provisioner not found"), - http.StatusNotFound, apiCtx{}} + return nil, errs.NotFound("provisioner not found") } return p, nil } @@ -40,8 +37,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { p, ok := a.provisioners.Load(id) if !ok { - return nil, &apiError{errors.Errorf("provisioner not found"), - http.StatusNotFound, apiCtx{}} + return nil, errs.NotFound("provisioner not found") } return p, nil } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index fb84a31d..1a45f209 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -7,13 +7,15 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" ) func TestGetEncryptedKey(t *testing.T) { type ek struct { - a *Authority - kid string - err *apiError + a *Authority + kid string + err error + code int } tests := map[string]func(t *testing.T) *ek{ "ok": func(t *testing.T) *ek { @@ -32,10 +34,10 @@ func TestGetEncryptedKey(t *testing.T) { a, err := New(c) assert.FatalError(t, err) return &ek{ - a: a, - kid: "foo", - err: &apiError{errors.Errorf("encrypted key with kid foo was not found"), - http.StatusNotFound, apiCtx{}}, + a: a, + kid: "foo", + err: errors.New("encrypted key with kid foo was not found"), + code: http.StatusNotFound, } }, } @@ -47,14 +49,10 @@ func TestGetEncryptedKey(t *testing.T) { ek, err := tc.a.GetEncryptedKey(tc.kid) if err != nil { if assert.NotNil(t, tc.err) { - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { diff --git a/authority/root.go b/authority/root.go index 3794a6c8..f391997f 100644 --- a/authority/root.go +++ b/authority/root.go @@ -2,23 +2,20 @@ package authority import ( "crypto/x509" - "net/http" - "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" ) // Root returns the certificate corresponding to the given SHA sum argument. func (a *Authority) Root(sum string) (*x509.Certificate, error) { val, ok := a.certificates.Load(sum) if !ok { - return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum), - http.StatusNotFound, apiCtx{}} + return nil, errs.NotFound("certificate with fingerprint %s was not found", sum) } crt, ok := val.(*x509.Certificate) if !ok { - return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"), - http.StatusInternalServerError, apiCtx{}} + return nil, errs.InternalServer("stored value is not a *x509.Certificate") } return crt, nil } @@ -52,8 +49,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) crt, ok := v.(*x509.Certificate) if !ok { federation = nil - err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"), - http.StatusInternalServerError, apiCtx{}} + err = errs.InternalServer("stored value is not a *x509.Certificate") return false } federation = append(federation, crt) diff --git a/authority/root_test.go b/authority/root_test.go index 4b648d78..a936b66f 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/errs" "github.com/smallstep/cli/crypto/pemutil" ) @@ -16,12 +17,13 @@ func TestRoot(t *testing.T) { a.certificates.Store("invaliddata", "a string") // invalid cert for testing tests := map[string]struct { - sum string - err *apiError + sum string + err error + code int }{ - "not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}}, - "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}}, - "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, + "not-found": {"foo", errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound}, + "invalid-stored-certificate": {"invaliddata", errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError}, + "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil, http.StatusOK}, } for name, tc := range tests { @@ -29,14 +31,10 @@ func TestRoot(t *testing.T) { crt, err := a.Root(tc.sum) if err != nil { if assert.NotNil(t, tc.err) { - switch v := err.(type) { - case *apiError: - assert.HasPrefix(t, v.err.Error(), tc.err.Error()) - assert.Equals(t, v.code, tc.err.code) - assert.Equals(t, v.context, tc.err.context) - default: - t.Errorf("unexpected error type: %T", v) - } + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tc.code) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { diff --git a/authority/ssh.go b/authority/ssh.go index 5d80a427..28066556 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -122,7 +122,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) { // GetSSHConfig returns rendered templates for clients (user) or servers (host). func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) { if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil { - return nil, errs.NotFound(errors.New("getSSHConfig: ssh is not configured")) + return nil, errs.NotFound("getSSHConfig: ssh is not configured") } var ts []templates.Template @@ -136,7 +136,7 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template ts = a.config.Templates.SSH.Host } default: - return nil, errs.BadRequest(errors.Errorf("getSSHConfig: type %s is not valid", typ)) + return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ) } // Merge user and default data @@ -177,13 +177,13 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error } return nil, nil } - return nil, errs.NotFound(errors.New("authority.GetSSHBastion; ssh is not configured")) + return nil, errs.NotFound("authority.GetSSHBastion; ssh is not configured") } // SignSSH creates a signed SSH certificate with the given public key and options. func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - var mods []provisioner.SSHCertificateModifier - var validators []provisioner.SSHCertificateValidator + var mods []provisioner.SSHCertModifier + var validators []provisioner.SSHCertValidator // Set backdate with the configured value opts.Backdate = a.config.AuthorityConfig.Backdate.Duration @@ -191,27 +191,27 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign for _, op := range signOpts { switch o := op.(type) { // modify the ssh.Certificate - case provisioner.SSHCertificateModifier: + case provisioner.SSHCertModifier: mods = append(mods, o) // modify the ssh.Certificate given the SSHOptions - case provisioner.SSHCertificateOptionModifier: + case provisioner.SSHCertOptionModifier: mods = append(mods, o.Option(opts)) // validate the ssh.Certificate - case provisioner.SSHCertificateValidator: + case provisioner.SSHCertValidator: validators = append(validators, o) // validate the given SSHOptions - case provisioner.SSHCertificateOptionsValidator: + case provisioner.SSHCertOptionsValidator: if err := o.Valid(opts); err != nil { - return nil, errs.Forbidden(err) + return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") } default: - return nil, errs.InternalServerError(errors.Errorf("signSSH: invalid extra option type %T", o)) + return nil, errs.InternalServer("signSSH: invalid extra option type %T", o) } } nonce, err := randutil.ASCII(32) if err != nil { - return nil, errs.InternalServerError(err) + return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH") } var serial uint64 @@ -228,13 +228,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign // Use opts to modify the certificate if err := opts.Modify(cert); err != nil { - return nil, errs.Forbidden(err) + return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") } // Use provisioner modifiers for _, m := range mods { if err := m.Modify(cert); err != nil { - return nil, errs.Forbidden(err) + return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") } } @@ -243,16 +243,16 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign switch cert.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("signSSH: user certificate signing is not enabled")) + return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled") } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("signSSH: host certificate signing is not enabled")) + return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled") } signer = a.sshCAHostCertSignKey default: - return nil, errs.InternalServerError(errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType)) + return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType) } cert.SignatureKey = signer.PublicKey() @@ -270,7 +270,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign // User provisioners validators for _, v := range validators { if err := v.Valid(cert); err != nil { - return nil, errs.Forbidden(err) + return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") } } @@ -285,7 +285,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) { nonce, err := randutil.ASCII(32) if err != nil { - return nil, errs.InternalServerError(err) + return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH") } var serial uint64 @@ -294,7 +294,7 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) } if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest(errors.New("rewnewSSH: cannot renew certificate without validity period")) + return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period") } backdate := a.config.AuthorityConfig.Backdate.Duration @@ -321,16 +321,16 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) switch cert.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("renewSSH: user certificate signing is not enabled")) + return nil, errs.NotImplemented("renewSSH: user certificate signing is not enabled") } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("renewSSH: host certificate signing is not enabled")) + return nil, errs.NotImplemented("renewSSH: host certificate signing is not enabled") } signer = a.sshCAHostCertSignKey default: - return nil, errs.InternalServerError(errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType)) + return nil, errs.InternalServer("renewSSH: unexpected ssh certificate type: %d", cert.CertType) } cert.SignatureKey = signer.PublicKey() @@ -354,21 +354,21 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) // RekeySSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - var validators []provisioner.SSHCertificateValidator + var validators []provisioner.SSHCertValidator for _, op := range signOpts { switch o := op.(type) { // validate the ssh.Certificate - case provisioner.SSHCertificateValidator: + case provisioner.SSHCertValidator: validators = append(validators, o) default: - return nil, errs.InternalServerError(errors.Errorf("rekeySSH; invalid extra option type %T", o)) + return nil, errs.InternalServer("rekeySSH; invalid extra option type %T", o) } } nonce, err := randutil.ASCII(32) if err != nil { - return nil, errs.InternalServerError(err) + return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH") } var serial uint64 @@ -377,7 +377,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp } if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest(errors.New("rekeySSH; cannot rekey certificate without validity period")) + return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period") } backdate := a.config.AuthorityConfig.Backdate.Duration @@ -404,16 +404,16 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp switch cert.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("rekeySSH; user certificate signing is not enabled")) + return nil, errs.NotImplemented("rekeySSH; user certificate signing is not enabled") } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("rekeySSH; host certificate signing is not enabled")) + return nil, errs.NotImplemented("rekeySSH; host certificate signing is not enabled") } signer = a.sshCAHostCertSignKey default: - return nil, errs.BadRequest(errors.Errorf("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)) + return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType) } cert.SignatureKey = signer.PublicKey() @@ -431,7 +431,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp // Apply validators from provisioner.. for _, v := range validators { if err := v.Valid(cert); err != nil { - return nil, errs.Forbidden(err) + return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH") } } @@ -445,18 +445,18 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp // SignSSHAddUser signs a certificate that provisions a new user in a server. func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { if a.sshCAUserCertSignKey == nil { - return nil, errs.NotImplemented(errors.New("signSSHAddUser: user certificate signing is not enabled")) + return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled") } if subject.CertType != ssh.UserCert { - return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate is not a user certificate")) + return nil, errs.Forbidden("signSSHAddUser: certificate is not a user certificate") } if len(subject.ValidPrincipals) != 1 { - return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate does not have only one principal")) + return nil, errs.Forbidden("signSSHAddUser: certificate does not have only one principal") } nonce, err := randutil.ASCII(32) if err != nil { - return nil, errs.InternalServerError(err) + return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser") } var serial uint64 diff --git a/authority/ssh_test.go b/authority/ssh_test.go index db5dc85d..cc3f164c 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -80,7 +80,7 @@ func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error { type sshTestOptionsModifier string -func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier { +func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertModifier { return sshTestCertModifier(string(m)) } @@ -492,12 +492,12 @@ func TestAuthority_CheckSSHHost(t *testing.T) { want bool wantErr bool }{ - {"true", fields{true, nil}, args{context.TODO(), "foo.internal.com", ""}, true, false}, - {"false", fields{false, nil}, args{context.TODO(), "foo.internal.com", ""}, false, false}, - {"notImplemented", fields{false, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true}, - {"notImplemented", fields{true, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true}, - {"internal", fields{false, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true}, - {"internal", fields{true, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true}, + {"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false}, + {"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false}, + {"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true}, + {"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true}, + {"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true}, + {"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/tls.go b/authority/tls.go index 9199c040..03a9ec33 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -61,7 +61,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption { // Sign creates a signed certificate from a certificate signing request. func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { var ( - opts = []errs.Option{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} + opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)} certValidators = []provisioner.CertificateValidator{} issIdentity = a.intermediateIdentity @@ -81,7 +81,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti case provisioner.ProfileModifier: mods = append(mods, k.Option(signOpts)) default: - return nil, errs.InternalServerError(errors.Errorf("authority.Sign; invalid extra option type %T", k), opts...) + return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...) } } @@ -131,7 +131,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti // Renew creates a new Certificate identical to the old certificate, except // with a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { - opts := []errs.Option{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())} + opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())} // Check step provisioner extensions if err := a.authorizeRenew(oldCert); err != nil { @@ -237,7 +237,7 @@ type RevokeOptions struct { // // TODO: Add OCSP and CRL support. func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error { - opts := []errs.Option{ + opts := []interface{}{ errs.WithKeyVal("serialNumber", revokeOpts.Serial), errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode), errs.WithKeyVal("reason", revokeOpts.Reason), @@ -281,7 +281,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error var ok bool p, ok = a.provisioners.LoadByToken(token, &claims.Claims) if !ok { - return errs.InternalServerError(errors.Errorf("authority.Revoke; provisioner not found"), opts...) + return errs.InternalServer("authority.Revoke; provisioner not found", opts...) } rci.TokenID, err = p.GetTokenID(revokeOpts.OTT) if err != nil { @@ -309,10 +309,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error case nil: return nil case db.ErrNotImplemented: - return errs.NotImplemented(errors.New("authority.Revoke; no persistence layer configured"), opts...) + return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...) case db.ErrAlreadyExists: - return errs.BadRequest(errors.Errorf("authority.Revoke; certificate with serial "+ - "number %s has already been revoked", rci.Serial), opts...) + return errs.BadRequest("authority.Revoke; certificate with serial "+ + "number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...) default: return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) } diff --git a/ca/client.go b/ca/client.go index e6fdab92..ce936655 100644 --- a/ca/client.go +++ b/ca/client.go @@ -553,7 +553,7 @@ retry: // verify the sha256 sum := sha256.Sum256(root.RootPEM.Raw) if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) { - return nil, errs.BadRequest(errors.New("client.Root; root certificate SHA256 fingerprint do not match")) + return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") } return &root, nil } @@ -961,8 +961,8 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { - return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u, - errs.WithMessage("Failed to perform POST request to %s", u)) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", + []interface{}{u, errs.WithMessage("Failed to perform POST request to %s", u)}...) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { @@ -974,8 +974,8 @@ retry: } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { - return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u, - errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")) + return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", + []interface{}{u, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")}) } return &check, nil } diff --git a/ca/client_test.go b/ca/client_test.go index 5b74f5cb..f880c876 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -163,8 +163,8 @@ func TestClient_Version(t *testing.T) { expectedErr error }{ {"ok", ok, 200, false, nil}, - {"500", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, - {"404", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, + {"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, + {"404", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -214,7 +214,7 @@ func TestClient_Health(t *testing.T) { expectedErr error }{ {"ok", ok, 200, false, nil}, - {"not ok", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, + {"not ok", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -268,7 +268,7 @@ func TestClient_Root(t *testing.T) { expectedErr error }{ {"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil}, - {"not found", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, + {"not found", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -336,9 +336,9 @@ func TestClient_Sign(t *testing.T) { expectedErr error }{ {"ok", request, ok, 200, false, nil}, - {"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", &api.SignRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -409,8 +409,8 @@ func TestClient_Revoke(t *testing.T) { expectedErr error }{ {"ok", request, ok, 200, false, nil}, - {"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -483,9 +483,9 @@ func TestClient_Renew(t *testing.T) { err error }{ {"ok", ok, 200, false, nil}, - {"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -533,7 +533,7 @@ func TestClient_Provisioners(t *testing.T) { ok := &api.ProvisionersResponse{ Provisioners: provisioner.List{}, } - internalServerError := errs.InternalServerError(fmt.Errorf("Internal Server Error")) + internalServerError := errs.InternalServer("Internal Server Error") tests := []struct { name string @@ -603,7 +603,7 @@ func TestClient_ProvisionerKey(t *testing.T) { err error }{ {"ok", "kid", ok, 200, false, nil}, - {"fail", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, + {"fail", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -665,8 +665,8 @@ func TestClient_Roots(t *testing.T) { err error }{ {"ok", ok, 200, false, nil}, - {"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"bad-request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -724,7 +724,7 @@ func TestClient_Federation(t *testing.T) { err error }{ {"ok", ok, 200, false, nil}, - {"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -786,7 +786,7 @@ func TestClient_SSHRoots(t *testing.T) { err error }{ {"ok", ok, 200, false, nil}, - {"not found", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)}, + {"not found", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) @@ -869,7 +869,7 @@ func Test_parseEndpoint(t *testing.T) { func TestClient_RootFingerprint(t *testing.T) { ok := &api.HealthResponse{Status: "ok"} - nok := errs.InternalServerError(fmt.Errorf("Internal Server Error")) + nok := errs.InternalServer("Internal Server Error") httpsServer := httptest.NewTLSServer(nil) defer httpsServer.Close() @@ -947,7 +947,7 @@ func TestClient_SSHBastion(t *testing.T) { }{ {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, - {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, } srv := httptest.NewServer(nil) diff --git a/errs/error.go b/errs/error.go index adae017e..2e49d8c5 100644 --- a/errs/error.go +++ b/errs/error.go @@ -62,31 +62,6 @@ type Error struct { Details map[string]interface{} } -// New returns a new Error. If the given error implements the StatusCoder -// interface we will ignore the given status. -func New(status int, err error, opts ...Option) error { - var ( - e *Error - ok bool - ) - if e, ok = err.(*Error); !ok { - if sc, ok := err.(StatusCoder); ok { - e = &Error{Status: sc.StatusCode(), Err: err} - } else { - cause := errors.Cause(err) - if sc, ok := cause.(StatusCoder); ok { - e = &Error{Status: sc.StatusCode(), Err: err} - } else { - e = &Error{Status: status, Err: err} - } - } - } - for _, o := range opts { - o(e) - } - return e -} - // ErrorResponse represents an error in JSON format. type ErrorResponse struct { Status int `json:"status"` @@ -119,10 +94,11 @@ func (e *Error) Message() string { // Wrap returns an error annotating err with a stack trace at the point Wrap is // called, and the supplied message. If err is nil, Wrap returns nil. -func Wrap(status int, e error, m string, opts ...Option) error { +func Wrap(status int, e error, m string, args ...interface{}) error { if e == nil { return nil } + _, opts := splitOptionArgs(args) if err, ok := e.(*Error); ok { err.Err = errors.Wrap(err.Err, m) e = err @@ -138,25 +114,12 @@ func Wrapf(status int, e error, format string, args ...interface{}) error { if e == nil { return nil } - var opts []Option - for i, arg := range args { - // Once we find the first Option, assume that all further arguments are Options. - if _, ok := arg.(Option); ok { - for _, a := range args[i:] { - // Ignore any arguments after the first Option that are not Options. - if opt, ok := a.(Option); ok { - opts = append(opts, opt) - } - } - args = args[:i] - break - } - } + as, opts := splitOptionArgs(args) if err, ok := e.(*Error); ok { err.Err = errors.Wrapf(err.Err, format, args...) e = err } else { - e = errors.Wrapf(e, format, args...) + e = errors.Wrapf(e, format, as...) } return StatusCodeError(status, e, opts...) } @@ -201,24 +164,24 @@ type Messenger interface { func StatusCodeError(code int, e error, opts ...Option) error { switch code { case http.StatusBadRequest: - return BadRequest(e, opts...) + return BadRequestErr(e, opts...) case http.StatusUnauthorized: - return Unauthorized(e, opts...) + return UnauthorizedErr(e, opts...) case http.StatusForbidden: - return Forbidden(e, opts...) + return ForbiddenErr(e, opts...) case http.StatusInternalServerError: - return InternalServerError(e, opts...) + return InternalServerErr(e, opts...) case http.StatusNotImplemented: - return NotImplemented(e, opts...) + return NotImplementedErr(e, opts...) default: - return UnexpectedError(code, e, opts...) + return UnexpectedErr(code, e, opts...) } } var ( seeLogs = "Please see the certificate authority logs for more info." // BadRequestDefaultMsg 400 default msg - BadRequestDefaultMsg = "The request could not be completed due to being poorly formatted or missing critical data. " + seeLogs + BadRequestDefaultMsg = "The request could not be completed; malformed or missing data" + seeLogs // UnauthorizedDefaultMsg 401 default msg UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs // ForbiddenDefaultMsg 403 default msg @@ -231,46 +194,142 @@ var ( NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs ) -// InternalServerError returns a 500 error with the given error. -func InternalServerError(err error, opts ...Option) error { - opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg)) - return New(http.StatusInternalServerError, err, opts...) +// splitOptionArgs splits the variadic length args into string formatting args +// and Option(s) to apply to an Error. +func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { + indexOptionStart := -1 + for i, a := range args { + if _, ok := a.(Option); ok { + indexOptionStart = i + break + } + } + + if indexOptionStart < 0 { + return args, []Option{} + } + opts := []Option{} + // Ignore any non-Option args that come after the first Option. + for _, o := range args[indexOptionStart:] { + if opt, ok := o.(Option); ok { + opts = append(opts, opt) + } + } + return args[:indexOptionStart], opts } -// NotImplemented returns a 501 error with the given error. -func NotImplemented(err error, opts ...Option) error { +// NewErr returns a new Error. If the given error implements the StatusCoder +// interface we will ignore the given status. +func NewErr(status int, err error, opts ...Option) error { + var ( + e *Error + ok bool + ) + if e, ok = err.(*Error); !ok { + if sc, ok := err.(StatusCoder); ok { + e = &Error{Status: sc.StatusCode(), Err: err} + } else { + cause := errors.Cause(err) + if sc, ok := cause.(StatusCoder); ok { + e = &Error{Status: sc.StatusCode(), Err: err} + } else { + e = &Error{Status: status, Err: err} + } + } + } + for _, o := range opts { + o(e) + } + return e +} + +// Errorf creates a new error using the given format and status code. +func Errorf(code int, format string, args ...interface{}) error { + as, opts := splitOptionArgs(args) opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg)) - return New(http.StatusNotImplemented, err, opts...) + e := &Error{Status: code, Err: fmt.Errorf(format, as...)} + for _, o := range opts { + o(e) + } + return e } -// BadRequest returns an 400 error with the given error. -func BadRequest(err error, opts ...Option) error { +// InternalServer creates a 500 error with the given format and arguments. +func InternalServer(format string, args ...interface{}) error { + args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg)) + return Errorf(http.StatusInternalServerError, format, args...) +} + +// InternalServerErr returns a 500 error with the given error. +func InternalServerErr(err error, opts ...Option) error { + opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg)) + return NewErr(http.StatusInternalServerError, err, opts...) +} + +// NotImplemented creates a 501 error with the given format and arguments. +func NotImplemented(format string, args ...interface{}) error { + args = append(args, withDefaultMessage(NotImplementedDefaultMsg)) + return Errorf(http.StatusNotImplemented, format, args...) +} + +// NotImplementedErr returns a 501 error with the given error. +func NotImplementedErr(err error, opts ...Option) error { + opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg)) + return NewErr(http.StatusNotImplemented, err, opts...) +} + +// BadRequest creates a 400 error with the given format and arguments. +func BadRequest(format string, args ...interface{}) error { + args = append(args, withDefaultMessage(BadRequestDefaultMsg)) + return Errorf(http.StatusBadRequest, format, args...) +} + +// BadRequestErr returns an 400 error with the given error. +func BadRequestErr(err error, opts ...Option) error { opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) - return New(http.StatusBadRequest, err, opts...) + return NewErr(http.StatusBadRequest, err, opts...) } -// Unauthorized returns an 401 error with the given error. -func Unauthorized(err error, opts ...Option) error { +// Unauthorized creates a 401 error with the given format and arguments. +func Unauthorized(format string, args ...interface{}) error { + args = append(args, withDefaultMessage(UnauthorizedDefaultMsg)) + return Errorf(http.StatusUnauthorized, format, args...) +} + +// UnauthorizedErr returns an 401 error with the given error. +func UnauthorizedErr(err error, opts ...Option) error { opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg)) - return New(http.StatusUnauthorized, err, opts...) + return NewErr(http.StatusUnauthorized, err, opts...) } -// Forbidden returns an 403 error with the given error. -func Forbidden(err error, opts ...Option) error { +// Forbidden creates a 403 error with the given format and arguments. +func Forbidden(format string, args ...interface{}) error { + args = append(args, withDefaultMessage(ForbiddenDefaultMsg)) + return Errorf(http.StatusForbidden, format, args...) +} + +// ForbiddenErr returns an 403 error with the given error. +func ForbiddenErr(err error, opts ...Option) error { opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg)) - return New(http.StatusForbidden, err, opts...) + return NewErr(http.StatusForbidden, err, opts...) } -// NotFound returns an 404 error with the given error. -func NotFound(err error, opts ...Option) error { +// NotFound creates a 404 error with the given format and arguments. +func NotFound(format string, args ...interface{}) error { + args = append(args, withDefaultMessage(NotFoundDefaultMsg)) + return Errorf(http.StatusNotFound, format, args...) +} + +// NotFoundErr returns an 404 error with the given error. +func NotFoundErr(err error, opts ...Option) error { opts = append(opts, withDefaultMessage(NotFoundDefaultMsg)) - return New(http.StatusNotFound, err, opts...) + return NewErr(http.StatusNotFound, err, opts...) } -// UnexpectedError will be used when the certificate authority makes an outgoing +// UnexpectedErr will be used when the certificate authority makes an outgoing // request and receives an unhandled status code. -func UnexpectedError(code int, err error, opts ...Option) error { +func UnexpectedErr(code int, err error, opts ...Option) error { opts = append(opts, withDefaultMessage("The certificate authority received an "+ "unexpected HTTP status code - '%d'. "+seeLogs, code)) - return New(code, err, opts...) + return NewErr(code, err, opts...) }