Simplify statuscoder error generators.

This commit is contained in:
max furman 2020-01-23 22:04:34 -08:00
parent dccbdf3a90
commit 1cb8bb3ae1
45 changed files with 483 additions and 441 deletions

View file

@ -63,6 +63,7 @@ issues:
- declaration of "err" shadows declaration at line
- should have a package comment, unless it's in another file for this package
- error strings should not be capitalized or end with punctuation or a newline
- Wrapf call needs 1 arg but has 2 args
# golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration
service:

View file

@ -295,7 +295,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the
cert, err := h.Authority.Root(sum)
if err != nil {
WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI)))
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return
}
@ -314,13 +314,13 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := parseCursor(r)
if err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
JSON(w, &ProvisionersResponse{
@ -334,7 +334,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid)
if err != nil {
WriteError(w, errs.NotFound(err))
WriteError(w, errs.NotFoundErr(err))
return
}
JSON(w, &ProvisionerKeyResponse{key})
@ -344,7 +344,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
@ -362,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation()
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -915,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) {
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, errs.Forbidden(fmt.Errorf("an error")), http.StatusForbidden},
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
}
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
@ -1010,10 +1010,10 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err)
}
expectedError400 := errs.BadRequest(errors.New("force"))
expectedError400 := errs.BadRequest("force")
expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err)
expectedError500 := errs.InternalServerError(errors.New("force"))
expectedError500 := errs.InternalServer("force")
expectedError500Bytes, err := json.Marshal(expectedError500)
assert.FatalError(t, err)
for _, tt := range tests {
@ -1082,7 +1082,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
}
expected := []byte(`{"key":"` + privKey + `"}`)
expectedError404 := errs.NotFound(errors.New("force"))
expectedError404 := errs.NotFound("force")
expectedError404Bytes, err := json.Marshal(expectedError404)
assert.FatalError(t, err)

View file

@ -3,7 +3,6 @@ package api
import (
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
)
@ -11,7 +10,7 @@ import (
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest(errors.New("missing peer certificate")))
WriteError(w, errs.BadRequest("missing peer certificate"))
return
}
@ -22,7 +21,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 0 {
if len(certChainPEM) > 1 {
caPEM = certChainPEM[1]
}

View file

@ -4,7 +4,6 @@ import (
"context"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -30,13 +29,13 @@ type RevokeRequest struct {
// or an error if something is wrong.
func (r *RevokeRequest) Validate() (err error) {
if r.Serial == "" {
return errs.BadRequest(errors.New("missing serial"))
return errs.BadRequest("missing serial")
}
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return errs.BadRequest(errors.New("reasonCode out of bounds"))
return errs.BadRequest("reasonCode out of bounds")
}
if !r.Passive {
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
return errs.NotImplemented("non-passive revocation not implemented")
}
return
@ -50,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
@ -72,7 +71,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 {
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
@ -81,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number
// being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate")))
WriteError(w, errs.BadRequest("missing ott or peer certificate"))
return
}
opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body")))
WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body"))
return
}
// TODO: should probably be checking if the certificate was revoked here.
@ -97,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
}
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -190,7 +190,7 @@ func Test_caHandler_Revoke(t *testing.T) {
return nil, nil
},
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
return errs.InternalServerError(errors.New("force"))
return errs.InternalServer("force")
},
},
}

View file

@ -4,7 +4,6 @@ import (
"crypto/tls"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/tlsutil"
@ -22,13 +21,13 @@ type SignRequest struct {
// or an error if something is wrong.
func (s *SignRequest) Validate() error {
if s.CsrPEM.CertificateRequest == nil {
return errs.BadRequest(errors.New("missing csr"))
return errs.BadRequest("missing csr")
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.BadRequest(errors.Wrap(err, "invalid csr"))
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
}
if s.OTT == "" {
return errs.BadRequest(errors.New("missing ott"))
return errs.BadRequest("missing ott")
}
return nil
@ -49,7 +48,7 @@ type SignResponse struct {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
@ -66,18 +65,18 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
certChainPEM := certChainToPEM(certChain)
var caPEM Certificate
if len(certChainPEM) > 0 {
if len(certChainPEM) > 1 {
caPEM = certChainPEM[1]
}
logCertificate(w, certChain[0])

View file

@ -249,19 +249,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
return
}
@ -269,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey"))
return
}
}
@ -285,13 +285,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
@ -299,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
addUserCertificate = &SSHCertificate{addUserCert}
@ -320,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
certChain, err := h.Authority.Sign(cr, opts, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
identityCertificate = certChainToPEM(certChain)
@ -343,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots()
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound(errors.New("no keys found")))
WriteError(w, errs.NotFound("no keys found"))
return
}
@ -368,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation()
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound(errors.New("no keys found")))
WriteError(w, errs.NotFound("no keys found"))
return
}
@ -393,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
@ -414,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert:
config.HostTemplates = ts
default:
WriteError(w, errs.InternalServerError(errors.New("it should hot get here")))
WriteError(w, errs.InternalServer("it should hot get here"))
return
}
@ -429,13 +429,13 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHCheckPrincipalResponse{
@ -452,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(cert)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHGetHostsResponse{
@ -464,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
return
}

View file

@ -40,42 +40,42 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
return
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
}
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
identity, err := h.renewIdentityCertificate(r)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -36,36 +36,36 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequest(err))
WriteError(w, errs.BadRequestErr(err))
return
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, errs.InternalServerError(err))
WriteError(w, errs.InternalServerErr(err))
}
newCert, err := h.Authority.RenewSSH(oldCert)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}
identity, err := h.renewIdentityCertificate(r)
if err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -4,7 +4,6 @@ import (
"context"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -30,16 +29,16 @@ type SSHRevokeRequest struct {
// or an error if something is wrong.
func (r *SSHRevokeRequest) Validate() (err error) {
if r.Serial == "" {
return errs.BadRequest(errors.New("missing serial"))
return errs.BadRequest("missing serial")
}
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return errs.BadRequest(errors.New("reasonCode out of bounds"))
return errs.BadRequest("reasonCode out of bounds")
}
if !r.Passive {
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
return errs.NotImplemented("non-passive revocation not implemented")
}
if len(r.OTT) == 0 {
return errs.BadRequest(errors.New("missing ott"))
return errs.BadRequest("missing ott")
}
return
}
@ -50,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
@ -71,13 +70,13 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
// otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.Unauthorized(err))
WriteError(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.Forbidden(err))
WriteError(w, errs.ForbiddenErr(err))
return
}

View file

@ -6,7 +6,6 @@ import (
"log"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
)
@ -69,7 +68,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
// pointed by v.
func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequest(errors.Wrap(err, "error decoding json"))
return errs.Wrap(http.StatusBadRequest, err, "error decoding json")
}
return nil
}

View file

@ -6,7 +6,6 @@ import (
"net/http"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/jose"
@ -58,15 +57,15 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
return nil, errs.Unauthorized(errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"))
return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority")
}
}
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok {
return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: provisioner "+
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")))
return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
// Store the token to protect against reuse unless it's skipped.
@ -78,7 +77,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
"authority.authorizeToken: failed when attempting to store token")
}
if !ok {
return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: token already used"))
return nil, errs.Unauthorized("authority.authorizeToken: token already used")
}
}
}
@ -89,7 +88,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// Authorize grabs the method from the context and authorizes the request by
// validating the one-time-token.
func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) {
var opts = []errs.Option{errs.WithKeyVal("token", token)}
var opts = []interface{}{errs.WithKeyVal("token", token)}
switch m := provisioner.MethodFromContext(ctx); m {
case provisioner.SignMethod:
@ -99,13 +98,13 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...)
case provisioner.SSHSignMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...)
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
}
_, err := a.authorizeSSHSign(ctx, token)
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
case provisioner.SSHRenewMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...)
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
}
_, err := a.authorizeSSHRenew(ctx, token)
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
@ -113,12 +112,12 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...)
case provisioner.SSHRekeyMethod:
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...)
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
}
_, signOpts, err := a.authorizeSSHRekey(ctx, token)
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
default:
return nil, errs.InternalServerError(errors.Errorf("authority.Authorize; method %d is not supported", m), opts...)
return nil, errs.InternalServer("authority.Authorize; method %d is not supported", append([]interface{}{m}, opts...)...)
}
}
@ -165,7 +164,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
//
// TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
var opts = []errs.Option{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
// Check the passive revocation table.
isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String())
@ -173,12 +172,12 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
}
if isRevoked {
return errs.Unauthorized(errors.New("authority.authorizeRenew: certificate has been revoked"), opts...)
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
}
p, ok := a.provisioners.LoadByCertificate(cert)
if !ok {
return errs.Unauthorized(errors.New("authority.authorizeRenew: provisioner not found"), opts...)
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
}
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)

View file

@ -180,7 +180,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
_, err = _a.authorizeToken(context.TODO(), raw)
_, err = _a.authorizeToken(context.Background(), raw)
assert.FatalError(t, err)
return &authorizeTest{
auth: _a,
@ -268,7 +268,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
p, err := tc.auth.authorizeToken(context.TODO(), tc.token)
p, err := tc.auth.authorizeToken(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
@ -355,7 +355,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
if err := tc.auth.authorizeRevoke(context.TODO(), tc.token); err != nil {
if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")

View file

@ -80,7 +80,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// certificate was configured to allow renewals.
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()))
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())
}
return nil
}

View file

@ -306,7 +306,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// certificate was configured to allow renewals.
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID()))
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())
}
return nil
}
@ -353,7 +353,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token")
}
if len(jwt.Headers) == 0 {
return nil, errs.InternalServerError(errors.New("aws.authorizeToken; error parsing token, header is missing"))
return nil, errs.InternalServer("aws.authorizeToken; error parsing token, header is missing")
}
var unsafeClaims awsPayload
@ -378,13 +378,13 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
switch {
case doc.AccountID == "":
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document accountId cannot be empty"))
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document accountId cannot be empty")
case doc.InstanceID == "":
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty"))
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document instanceId cannot be empty")
case doc.PrivateIP == "":
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty"))
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document privateIp cannot be empty")
case doc.Region == "":
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document region cannot be empty"))
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document region cannot be empty")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -399,7 +399,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) {
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)"))
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
}
// Validate subject, it has to be known if disableCustomSANs is enabled
@ -407,7 +407,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
if payload.Subject != doc.InstanceID &&
payload.Subject != doc.PrivateIP &&
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)"))
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
}
}
@ -421,14 +421,14 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
}
}
if !found {
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid"))
return nil, errs.Unauthorized("aws.authorizeToken; invalid aws identity document - accountId is not valid")
}
}
// validate instance age
if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(doc.PendingTime) > d {
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document pendingTime is too old"))
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document pendingTime is too old")
}
}
@ -439,7 +439,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID()))
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())
}
claims, err := p.authorizeToken(token)
if err != nil {
@ -462,7 +462,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
@ -474,8 +474,8 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -704,7 +704,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)

View file

@ -210,14 +210,14 @@ func (p *Azure) Init(config Config) (err error) {
return nil
}
// authorizeToken returs the claims, name, group, error.
// authorizeToken returns the claims, name, group, error.
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
}
if len(jwt.Headers) == 0 {
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token missing header"))
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
}
var found bool
@ -230,7 +230,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
}
}
if !found {
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; cannot validate azure token"))
return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
}
if err := claims.ValidateWithLeeway(jose.Expected{
@ -243,12 +243,12 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
// Validate TenantID
if claims.TenantID != p.TenantID {
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)"))
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
}
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 {
return nil, "", "", errs.Unauthorized(errors.Errorf("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID))
return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
}
group, name := re[2], re[3]
return &claims, name, group, nil
@ -272,7 +272,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
}
if !found {
return nil, errs.Unauthorized(errors.New("azure.AuthorizeSign; azure token validation failed - invalid resource group"))
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid resource group")
}
}
@ -302,7 +302,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// certificate was configured to allow renewals.
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID()))
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())
}
return nil
}
@ -310,7 +310,7 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID()))
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())
}
_, name, _, err := p.authorizeToken(token)
@ -328,7 +328,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
Principals: []string{name},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
@ -340,9 +340,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -488,7 +488,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)

View file

@ -243,7 +243,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID()))
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())
}
return nil
}
@ -264,7 +264,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token")
}
if len(jwt.Headers) == 0 {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; error parsing gcp token - header is missing"))
return nil, errs.Unauthorized("gcp.authorizeToken; error parsing gcp token - header is missing")
}
var found bool
@ -278,7 +278,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
}
}
if !found {
return nil, errs.Unauthorized(errors.Errorf("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid))
return nil, errs.Unauthorized("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -293,7 +293,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)"))
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
}
// validate subject (service account)
@ -306,7 +306,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
}
}
if !found {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim"))
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid subject claim")
}
}
@ -320,26 +320,26 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
}
}
if !found {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid project id"))
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid project id")
}
}
// validate instance age
if d := p.InstanceAge.Value(); d > 0 {
if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d {
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old"))
return nil, errs.Unauthorized("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")
}
}
switch {
case claims.Google.ComputeEngine.InstanceID == "":
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty"))
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")
case claims.Google.ComputeEngine.InstanceName == "":
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty"))
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")
case claims.Google.ComputeEngine.ProjectID == "":
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty"))
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")
case claims.Google.ComputeEngine.Zone == "":
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty"))
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")
}
return &claims, nil
@ -348,7 +348,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID()))
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())
}
claims, err := p.authorizeToken(token)
if err != nil {
@ -371,7 +371,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
},
}
// Validate user options
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
// Set defaults if not given as user options
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
@ -383,8 +383,8 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -680,7 +680,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)

View file

@ -120,12 +120,12 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) {
return nil, errs.Unauthorized(errors.Errorf("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s",
audiences, claims.Audience))
return nil, errs.Unauthorized("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s",
audiences, claims.Audience)
}
if claims.Subject == "" {
return nil, errs.Unauthorized(errors.New("jwk.authorizeToken; jwk token subject cannot be empty"))
return nil, errs.Unauthorized("jwk.authorizeToken; jwk token subject cannot be empty")
}
return &claims, nil
@ -173,7 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID()))
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())
}
return nil
}
@ -181,20 +181,20 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID()))
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())
}
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
}
if claims.Step == nil || claims.Step.SSH == nil {
return nil, errs.Unauthorized(errors.New("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token"))
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")
}
opts := claims.Step.SSH
signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts),
sshCertOptionsValidator(*opts),
}
t := now()
@ -231,9 +231,9 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -222,7 +222,7 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token); err != nil {
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -337,7 +337,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)

View file

@ -149,7 +149,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
claims k8sSAPayload
)
if p.pubKeys == nil {
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented"))
return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")
/* NOTE: We plan to support the TokenReview API in a future release.
Below is some code that should be useful when we prioritize
this integration.
@ -177,7 +177,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
}
}
if !valid {
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims"))
return nil, errs.Unauthorized("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no
@ -189,7 +189,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
}
if claims.Subject == "" {
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA token subject cannot be empty"))
return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA token subject cannot be empty")
}
return &claims, nil
@ -221,7 +221,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()))
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())
}
return nil
}
@ -229,7 +229,7 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
// AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()))
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())
}
if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
@ -246,9 +246,9 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -363,10 +363,10 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case sshCertDefaultsModifier:
assert.Equals(t, v.CertType, SSHUserCert)
case *sshDefaultExtensionModifier:
case *sshCertificateValidityValidator:
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultPublicKeyValidator:
case *sshCertificateDefaultValidator:
case *sshCertDefaultValidator:
case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
default:

View file

@ -14,8 +14,8 @@ func Test_noop(t *testing.T) {
assert.Equals(t, "noop", p.GetName())
assert.Equals(t, noopType, p.GetType())
assert.Equals(t, nil, p.Init(Config{}))
assert.Equals(t, nil, p.AuthorizeRenew(context.TODO(), &x509.Certificate{}))
assert.Equals(t, nil, p.AuthorizeRevoke(context.TODO(), "foo"))
assert.Equals(t, nil, p.AuthorizeRenew(context.Background(), &x509.Certificate{}))
assert.Equals(t, nil, p.AuthorizeRevoke(context.Background(), "foo"))
kid, key, ok := p.GetEncryptedKey()
assert.Equals(t, "", kid)

View file

@ -195,12 +195,12 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
// Validate azp if present
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: invalid azp"))
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: invalid azp")
}
// Enforce an email claim
if p.Email == "" {
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email not found"))
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email not found")
}
// Validate domains (case-insensitive)
@ -214,7 +214,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
}
}
if !found {
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email is not allowed"))
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email is not allowed")
}
}
@ -230,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
}
}
if !found {
return errs.Unauthorized(errors.New("validatePayload: oidc token payload validation failed: invalid group"))
return errs.Unauthorized("validatePayload: oidc token payload validation failed: invalid group")
}
}
@ -263,7 +263,7 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
}
}
if !found {
return nil, errs.Unauthorized(errors.New("oidc.AuthorizeToken; cannot validate oidc token"))
return nil, errs.Unauthorized("oidc.AuthorizeToken; cannot validate oidc token")
}
if err := o.ValidatePayload(claims); err != nil {
@ -286,7 +286,7 @@ func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
if o.IsAdmin(claims.Email) {
return nil
}
return errs.Unauthorized(errors.New("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token"))
return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
}
// AuthorizeSign validates the given token.
@ -318,7 +318,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// certificate was configured to allow renewals.
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if o.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID()))
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())
}
return nil
}
@ -326,7 +326,7 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !o.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID()))
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())
}
claims, err := o.authorizeToken(token)
if err != nil {
@ -352,7 +352,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Non-admin users can only use principals returned by the identityFunc, and
// can only sign user certificates.
if !o.IsAdmin(claims.Email) {
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
}
// Default to a user certificate with usernames as principals if those options
@ -367,9 +367,9 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{o.claimer},
&sshCertValidityValidator{o.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}
@ -382,7 +382,7 @@ func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
// Only admins can revoke certificates.
if !o.IsAdmin(claims.Email) {
return errs.Unauthorized(errors.New("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token"))
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
}
return nil
}

View file

@ -284,43 +284,43 @@ type base struct{}
// AuthorizeSign returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing x509 Certificates.
func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSign not implemented"))
return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented")
}
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking x509 Certificates.
func (b *base) AuthorizeRevoke(ctx context.Context, token string) error {
return errs.Unauthorized(errors.New("provisioner.AuthorizeRevoke not implemented"))
return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented")
}
// AuthorizeRenew returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing x509 Certificates.
func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
return errs.Unauthorized(errors.New("provisioner.AuthorizeRenew not implemented"))
return errs.Unauthorized("provisioner.AuthorizeRenew not implemented")
}
// AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing SSH Certificates.
func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHSign not implemented"))
return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented")
}
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking SSH Certificates.
func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRevoke not implemented"))
return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented")
}
// AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing SSH Certificates.
func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRenew not implemented"))
return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented")
}
// AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite
// this method if they will support authorizing tokens for rekeying SSH Certificates.
func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
return nil, nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRekey not implemented"))
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
}
// Identity is the type representing an externally supplied identity that is used

View file

@ -19,29 +19,29 @@ const (
SSHHostCert = "host"
)
// SSHCertificateModifier is the interface used to change properties in an SSH
// SSHCertModifier is the interface used to change properties in an SSH
// certificate.
type SSHCertificateModifier interface {
type SSHCertModifier interface {
SignOption
Modify(cert *ssh.Certificate) error
}
// SSHCertificateOptionModifier is the interface used to add custom options used
// SSHCertOptionModifier is the interface used to add custom options used
// to modify the SSH certificate.
type SSHCertificateOptionModifier interface {
type SSHCertOptionModifier interface {
SignOption
Option(o SSHOptions) SSHCertificateModifier
Option(o SSHOptions) SSHCertModifier
}
// SSHCertificateValidator is the interface used to validate an SSH certificate.
type SSHCertificateValidator interface {
// SSHCertValidator is the interface used to validate an SSH certificate.
type SSHCertValidator interface {
SignOption
Valid(cert *ssh.Certificate) error
}
// SSHCertificateOptionsValidator is the interface used to validate the custom
// SSHCertOptionsValidator is the interface used to validate the custom
// options used to modify the SSH certificate.
type SSHCertificateOptionsValidator interface {
type SSHCertOptionsValidator interface {
SignOption
Valid(got SSHOptions) error
}
@ -69,7 +69,7 @@ func (o SSHOptions) Type() uint32 {
return sshCertTypeUInt32(o.CertType)
}
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
switch o.CertType {
case "": // ignore
@ -116,7 +116,7 @@ func (o SSHOptions) match(got SSHOptions) error {
return nil
}
// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the
// sshCertPrincipalsModifier is an SSHCertModifier that sets the
// principals to the SSH certificate.
type sshCertPrincipalsModifier []string
@ -126,7 +126,7 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshCertKeyIDModifier is an SSHCertificateModifier that sets the given
// sshCertKeyIDModifier is an SSHCertModifier that sets the given
// Key ID in the SSH certificate.
type sshCertKeyIDModifier string
@ -135,7 +135,7 @@ func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshCertTypeModifier is an SSHCertificateModifier that sets the
// sshCertTypeModifier is an SSHCertModifier that sets the
// certificate type.
type sshCertTypeModifier string
@ -145,7 +145,7 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshCertValidAfterModifier is an SSHCertificateModifier that sets the
// sshCertValidAfterModifier is an SSHCertModifier that sets the
// ValidAfter in the SSH certificate.
type sshCertValidAfterModifier uint64
@ -154,7 +154,7 @@ func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshCertValidBeforeModifier is an SSHCertificateModifier that sets the
// sshCertValidBeforeModifier is an SSHCertModifier that sets the
// ValidBefore in the SSH certificate.
type sshCertValidBeforeModifier uint64
@ -163,11 +163,11 @@ func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshCertDefaultsModifier implements a SSHCertificateModifier that
// sshCertDefaultsModifier implements a SSHCertModifier that
// modifies the certificate with the given options if they are not set.
type sshCertDefaultsModifier SSHOptions
// Modify implements the SSHCertificateModifier interface.
// Modify implements the SSHCertModifier interface.
func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
if cert.CertType == 0 {
cert.CertType = sshCertTypeUInt32(m.CertType)
@ -184,7 +184,7 @@ func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
return nil
}
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
// sshDefaultExtensionModifier implements an SSHCertModifier that sets
// the default extensions in an SSH certificate.
type sshDefaultExtensionModifier struct{}
@ -208,14 +208,14 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
}
}
// sshDefaultDuration is an SSHCertificateModifier that sets the certificate
// sshDefaultDuration is an SSHCertModifier that sets the certificate
// ValidAfter and ValidBefore if they have not been set. It will fail if a
// CertType has not been set or is not valid.
type sshDefaultDuration struct {
*Claimer
}
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertificateModifier {
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertModifier {
return sshModifierFunc(func(cert *ssh.Certificate) error {
d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil {
@ -248,7 +248,7 @@ type sshLimitDuration struct {
NotAfter time.Time
}
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertModifier {
if m.NotAfter.IsZero() {
defaultDuration := &sshDefaultDuration{m.Claimer}
return defaultDuration.Option(o)
@ -295,22 +295,22 @@ func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
})
}
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
// sshCertOptionsValidator validates the user SSHOptions with the ones
// usually present in the token.
type sshCertificateOptionsValidator SSHOptions
type sshCertOptionsValidator SSHOptions
// Valid implements SSHCertificateOptionsValidator and returns nil if both
// Valid implements SSHCertOptionsValidator and returns nil if both
// SSHOptions match.
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error {
func (v sshCertOptionsValidator) Valid(got SSHOptions) error {
want := SSHOptions(v)
return want.match(got)
}
type sshCertificateValidityValidator struct {
type sshCertValidityValidator struct {
*Claimer
}
func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate) error {
switch {
case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0")
@ -355,12 +355,12 @@ func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
}
}
// sshCertificateDefaultValidator implements a simple validator for all the
// sshCertDefaultValidator implements a simple validator for all the
// fields in the SSH certificate.
type sshCertificateDefaultValidator struct{}
type sshCertDefaultValidator struct{}
// Valid returns an error if the given certificate does not contain the necessary fields.
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate) error {
switch {
case len(cert.Nonce) == 0:
return errors.New("ssh certificate nonce cannot be empty")

View file

@ -489,12 +489,12 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
}
}
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
sshPub, err := ssh.NewPublicKey(pub)
assert.FatalError(t, err)
v := sshCertificateDefaultValidator{}
v := sshCertDefaultValidator{}
tests := []struct {
name string
cert *ssh.Certificate
@ -670,10 +670,10 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
}
}
func Test_sshCertificateValidityValidator(t *testing.T) {
func Test_sshCertValidityValidator(t *testing.T) {
p, err := generateX5C(nil)
assert.FatalError(t, err)
v := sshCertificateValidityValidator{p.claimer}
v := sshCertValidityValidator{p.claimer}
n := now()
tests := []struct {
name string
@ -992,7 +992,7 @@ func Test_sshLimitDuration_Option(t *testing.T) {
name string
fields fields
args args
want SSHCertificateModifier
want SSHCertModifier
}{
// TODO: Add test cases.
}

View file

@ -45,22 +45,22 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
return nil, err
}
var mods []SSHCertificateModifier
var validators []SSHCertificateValidator
var mods []SSHCertModifier
var validators []SSHCertValidator
for _, op := range signOpts {
switch o := op.(type) {
// modify the ssh.Certificate
case SSHCertificateModifier:
case SSHCertModifier:
mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions
case SSHCertificateOptionModifier:
case SSHCertOptionModifier:
mods = append(mods, o.Option(opts))
// validate the ssh.Certificate
case SSHCertificateValidator:
case SSHCertValidator:
validators = append(validators, o)
// validate the given SSHOptions
case SSHCertificateOptionsValidator:
case SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil {
return nil, err
}

View file

@ -112,20 +112,20 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
return nil, errs.Wrap(http.StatusInternalServerError, err,
"sshpop.authorizeToken; error checking checking sshpop cert revocation")
} else if isRevoked {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate is revoked"))
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
}
// Check validity period of the certificate.
n := time.Now()
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future"))
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past"))
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
}
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok {
return nil, errs.InternalServerError(errors.New("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey"))
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
}
pubKey := sshCryptoPubKey.CryptoPublicKey()
@ -146,7 +146,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
}
}
if !found {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate"))
return nil, errs.Unauthorized("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")
}
// Using the ssh certificates key to validate the claims accomplishes two
@ -170,12 +170,12 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) {
return nil, errs.Unauthorized(errors.Errorf("sshpop.authorizeToken; sshpop token has invalid audience "+
"claim (aud): expected %s, but got %s", audiences, claims.Audience))
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token has invalid audience "+
"claim (aud): expected %s, but got %s", audiences, claims.Audience)
}
if claims.Subject == "" {
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty"))
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token subject cannot be empty")
}
claims.sshCert = sshCert
@ -190,8 +190,8 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
}
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
return errs.BadRequest(errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
"must be equivalent to sshpop certificate serial number"))
return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
"must be equivalent to sshpop certificate serial number")
}
return nil
}
@ -204,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
}
if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"))
return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")
}
return claims.sshCert, nil
@ -219,15 +219,15 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
}
if claims.sshCert.CertType != ssh.HostCert {
return nil, nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"))
return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")
}
return claims.sshCert, []SignOption{
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
}, nil
}

View file

@ -564,8 +564,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
for _, o := range opts {
switch v := o.(type) {
case *sshDefaultPublicKeyValidator:
case *sshCertificateDefaultValidator:
case *sshCertificateValidityValidator:
case *sshCertDefaultValidator:
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))

View file

@ -136,7 +136,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
leaf := verifiedChains[0][0]
if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 {
return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature"))
return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
}
// Using the leaf certificates key to validate the claims accomplishes two
@ -160,12 +160,12 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// validate audiences with the defaults
if !matchesAudience(claims.Audience, audiences) {
return nil, errs.Unauthorized(errors.Errorf("x5c.authorizeToken; x5c token has invalid audience "+
"claim (aud); expected %s, but got %s", audiences, claims.Audience))
return nil, errs.Unauthorized("x5c.authorizeToken; x5c token has invalid audience "+
"claim (aud); expected %s, but got %s", audiences, claims.Audience)
}
if claims.Subject == "" {
return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; x5c token subject cannot be empty"))
return nil, errs.Unauthorized("x5c.authorizeToken; x5c token subject cannot be empty")
}
// Save the verified chains on the x5c payload object.
@ -213,7 +213,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized(errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()))
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())
}
return nil
}
@ -221,7 +221,7 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized(errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()))
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())
}
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
@ -230,13 +230,13 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
}
if claims.Step == nil || claims.Step.SSH == nil {
return nil, errs.Unauthorized(errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"))
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")
}
opts := claims.Step.SSH
signOptions := []SignOption{
// validates user's SSHOptions with the ones in the token
sshCertificateOptionsValidator(*opts),
sshCertOptionsValidator(*opts),
}
// Add modifiers from custom claims
@ -272,8 +272,8 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate public key.
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertificateValidityValidator{p.claimer},
&sshCertValidityValidator{p.claimer},
// Require all the fields in the SSH certificate
&sshCertificateDefaultValidator{},
&sshCertDefaultValidator{},
), nil
}

View file

@ -548,7 +548,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil {
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -594,7 +594,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.TODO(), nil); err != nil {
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -754,7 +754,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil {
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -768,7 +768,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
nw := now()
for _, o := range opts {
switch v := o.(type) {
case sshCertificateOptionsValidator:
case sshCertOptionsValidator:
tc.claims.Step.SSH.ValidAfter.t = time.Time{}
tc.claims.Step.SSH.ValidBefore.t = time.Time{}
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
@ -787,10 +787,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertificateValidityValidator:
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
*sshCertificateDefaultValidator:
*sshCertDefaultValidator:
case sshCertKeyIDValidator:
assert.Equals(t, string(v), "foo")
default:

View file

@ -2,18 +2,16 @@ package authority
import (
"crypto/x509"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
key, ok := a.provisioners.LoadEncryptedKey(kid)
if !ok {
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
http.StatusNotFound, apiCtx{}}
return "", errs.NotFound("encrypted key with kid %s was not found", kid)
}
return key, nil
}
@ -30,8 +28,7 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
p, ok := a.provisioners.LoadByCertificate(crt)
if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"),
http.StatusNotFound, apiCtx{}}
return nil, errs.NotFound("provisioner not found")
}
return p, nil
}
@ -40,8 +37,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
p, ok := a.provisioners.Load(id)
if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"),
http.StatusNotFound, apiCtx{}}
return nil, errs.NotFound("provisioner not found")
}
return p, nil
}

View file

@ -7,13 +7,15 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
func TestGetEncryptedKey(t *testing.T) {
type ek struct {
a *Authority
kid string
err *apiError
err error
code int
}
tests := map[string]func(t *testing.T) *ek{
"ok": func(t *testing.T) *ek {
@ -34,8 +36,8 @@ func TestGetEncryptedKey(t *testing.T) {
return &ek{
a: a,
kid: "foo",
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
http.StatusNotFound, apiCtx{}},
err: errors.New("encrypted key with kid foo was not found"),
code: http.StatusNotFound,
}
},
}
@ -47,14 +49,10 @@ func TestGetEncryptedKey(t *testing.T) {
ek, err := tc.a.GetEncryptedKey(tc.kid)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {

View file

@ -2,23 +2,20 @@ package authority
import (
"crypto/x509"
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
)
// Root returns the certificate corresponding to the given SHA sum argument.
func (a *Authority) Root(sum string) (*x509.Certificate, error) {
val, ok := a.certificates.Load(sum)
if !ok {
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
http.StatusNotFound, apiCtx{}}
return nil, errs.NotFound("certificate with fingerprint %s was not found", sum)
}
crt, ok := val.(*x509.Certificate)
if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, apiCtx{}}
return nil, errs.InternalServer("stored value is not a *x509.Certificate")
}
return crt, nil
}
@ -52,8 +49,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
crt, ok := v.(*x509.Certificate)
if !ok {
federation = nil
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, apiCtx{}}
err = errs.InternalServer("stored value is not a *x509.Certificate")
return false
}
federation = append(federation, crt)

View file

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/pemutil"
)
@ -17,11 +18,12 @@ func TestRoot(t *testing.T) {
tests := map[string]struct {
sum string
err *apiError
err error
code int
}{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
"not-found": {"foo", errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound},
"invalid-stored-certificate": {"invaliddata", errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil, http.StatusOK},
}
for name, tc := range tests {
@ -29,14 +31,10 @@ func TestRoot(t *testing.T) {
crt, err := a.Root(tc.sum)
if err != nil {
if assert.NotNil(t, tc.err) {
switch v := err.(type) {
case *apiError:
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
assert.Equals(t, v.code, tc.err.code)
assert.Equals(t, v.context, tc.err.context)
default:
t.Errorf("unexpected error type: %T", v)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {

View file

@ -122,7 +122,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
// GetSSHConfig returns rendered templates for clients (user) or servers (host).
func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
return nil, errs.NotFound(errors.New("getSSHConfig: ssh is not configured"))
return nil, errs.NotFound("getSSHConfig: ssh is not configured")
}
var ts []templates.Template
@ -136,7 +136,7 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
ts = a.config.Templates.SSH.Host
}
default:
return nil, errs.BadRequest(errors.Errorf("getSSHConfig: type %s is not valid", typ))
return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ)
}
// Merge user and default data
@ -177,13 +177,13 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
}
return nil, nil
}
return nil, errs.NotFound(errors.New("authority.GetSSHBastion; ssh is not configured"))
return nil, errs.NotFound("authority.GetSSHBastion; ssh is not configured")
}
// SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var mods []provisioner.SSHCertificateModifier
var validators []provisioner.SSHCertificateValidator
var mods []provisioner.SSHCertModifier
var validators []provisioner.SSHCertValidator
// Set backdate with the configured value
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
@ -191,27 +191,27 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
for _, op := range signOpts {
switch o := op.(type) {
// modify the ssh.Certificate
case provisioner.SSHCertificateModifier:
case provisioner.SSHCertModifier:
mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions
case provisioner.SSHCertificateOptionModifier:
case provisioner.SSHCertOptionModifier:
mods = append(mods, o.Option(opts))
// validate the ssh.Certificate
case provisioner.SSHCertificateValidator:
case provisioner.SSHCertValidator:
validators = append(validators, o)
// validate the given SSHOptions
case provisioner.SSHCertificateOptionsValidator:
case provisioner.SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil {
return nil, errs.Forbidden(err)
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
}
default:
return nil, errs.InternalServerError(errors.Errorf("signSSH: invalid extra option type %T", o))
return nil, errs.InternalServer("signSSH: invalid extra option type %T", o)
}
}
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, errs.InternalServerError(err)
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH")
}
var serial uint64
@ -228,13 +228,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// Use opts to modify the certificate
if err := opts.Modify(cert); err != nil {
return nil, errs.Forbidden(err)
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
}
// Use provisioner modifiers
for _, m := range mods {
if err := m.Modify(cert); err != nil {
return nil, errs.Forbidden(err)
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
}
}
@ -243,16 +243,16 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
switch cert.CertType {
case ssh.UserCert:
if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("signSSH: user certificate signing is not enabled"))
return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled")
}
signer = a.sshCAUserCertSignKey
case ssh.HostCert:
if a.sshCAHostCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("signSSH: host certificate signing is not enabled"))
return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled")
}
signer = a.sshCAHostCertSignKey
default:
return nil, errs.InternalServerError(errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType))
return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType)
}
cert.SignatureKey = signer.PublicKey()
@ -270,7 +270,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// User provisioners validators
for _, v := range validators {
if err := v.Valid(cert); err != nil {
return nil, errs.Forbidden(err)
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
}
}
@ -285,7 +285,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) {
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, errs.InternalServerError(err)
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
}
var serial uint64
@ -294,7 +294,7 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
}
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest(errors.New("rewnewSSH: cannot renew certificate without validity period"))
return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period")
}
backdate := a.config.AuthorityConfig.Backdate.Duration
@ -321,16 +321,16 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
switch cert.CertType {
case ssh.UserCert:
if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("renewSSH: user certificate signing is not enabled"))
return nil, errs.NotImplemented("renewSSH: user certificate signing is not enabled")
}
signer = a.sshCAUserCertSignKey
case ssh.HostCert:
if a.sshCAHostCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("renewSSH: host certificate signing is not enabled"))
return nil, errs.NotImplemented("renewSSH: host certificate signing is not enabled")
}
signer = a.sshCAHostCertSignKey
default:
return nil, errs.InternalServerError(errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType))
return nil, errs.InternalServer("renewSSH: unexpected ssh certificate type: %d", cert.CertType)
}
cert.SignatureKey = signer.PublicKey()
@ -354,21 +354,21 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var validators []provisioner.SSHCertificateValidator
var validators []provisioner.SSHCertValidator
for _, op := range signOpts {
switch o := op.(type) {
// validate the ssh.Certificate
case provisioner.SSHCertificateValidator:
case provisioner.SSHCertValidator:
validators = append(validators, o)
default:
return nil, errs.InternalServerError(errors.Errorf("rekeySSH; invalid extra option type %T", o))
return nil, errs.InternalServer("rekeySSH; invalid extra option type %T", o)
}
}
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, errs.InternalServerError(err)
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH")
}
var serial uint64
@ -377,7 +377,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
}
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest(errors.New("rekeySSH; cannot rekey certificate without validity period"))
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
}
backdate := a.config.AuthorityConfig.Backdate.Duration
@ -404,16 +404,16 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
switch cert.CertType {
case ssh.UserCert:
if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("rekeySSH; user certificate signing is not enabled"))
return nil, errs.NotImplemented("rekeySSH; user certificate signing is not enabled")
}
signer = a.sshCAUserCertSignKey
case ssh.HostCert:
if a.sshCAHostCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("rekeySSH; host certificate signing is not enabled"))
return nil, errs.NotImplemented("rekeySSH; host certificate signing is not enabled")
}
signer = a.sshCAHostCertSignKey
default:
return nil, errs.BadRequest(errors.Errorf("rekeySSH; unexpected ssh certificate type: %d", cert.CertType))
return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)
}
cert.SignatureKey = signer.PublicKey()
@ -431,7 +431,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
// Apply validators from provisioner..
for _, v := range validators {
if err := v.Valid(cert); err != nil {
return nil, errs.Forbidden(err)
return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH")
}
}
@ -445,18 +445,18 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
// SignSSHAddUser signs a certificate that provisions a new user in a server.
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented(errors.New("signSSHAddUser: user certificate signing is not enabled"))
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
}
if subject.CertType != ssh.UserCert {
return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate is not a user certificate"))
return nil, errs.Forbidden("signSSHAddUser: certificate is not a user certificate")
}
if len(subject.ValidPrincipals) != 1 {
return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate does not have only one principal"))
return nil, errs.Forbidden("signSSHAddUser: certificate does not have only one principal")
}
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, errs.InternalServerError(err)
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser")
}
var serial uint64

View file

@ -80,7 +80,7 @@ func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
type sshTestOptionsModifier string
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier {
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertModifier {
return sshTestCertModifier(string(m))
}
@ -492,12 +492,12 @@ func TestAuthority_CheckSSHHost(t *testing.T) {
want bool
wantErr bool
}{
{"true", fields{true, nil}, args{context.TODO(), "foo.internal.com", ""}, true, false},
{"false", fields{false, nil}, args{context.TODO(), "foo.internal.com", ""}, false, false},
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true},
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true},
{"internal", fields{false, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true},
{"internal", fields{true, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true},
{"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false},
{"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false},
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
{"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -61,7 +61,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
// Sign creates a signed certificate from a certificate signing request.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var (
opts = []errs.Option{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
certValidators = []provisioner.CertificateValidator{}
issIdentity = a.intermediateIdentity
@ -81,7 +81,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
case provisioner.ProfileModifier:
mods = append(mods, k.Option(signOpts))
default:
return nil, errs.InternalServerError(errors.Errorf("authority.Sign; invalid extra option type %T", k), opts...)
return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...)
}
}
@ -131,7 +131,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
// Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'.
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
opts := []errs.Option{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
// Check step provisioner extensions
if err := a.authorizeRenew(oldCert); err != nil {
@ -237,7 +237,7 @@ type RevokeOptions struct {
//
// TODO: Add OCSP and CRL support.
func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error {
opts := []errs.Option{
opts := []interface{}{
errs.WithKeyVal("serialNumber", revokeOpts.Serial),
errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode),
errs.WithKeyVal("reason", revokeOpts.Reason),
@ -281,7 +281,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
var ok bool
p, ok = a.provisioners.LoadByToken(token, &claims.Claims)
if !ok {
return errs.InternalServerError(errors.Errorf("authority.Revoke; provisioner not found"), opts...)
return errs.InternalServer("authority.Revoke; provisioner not found", opts...)
}
rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
if err != nil {
@ -309,10 +309,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
case nil:
return nil
case db.ErrNotImplemented:
return errs.NotImplemented(errors.New("authority.Revoke; no persistence layer configured"), opts...)
return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
case db.ErrAlreadyExists:
return errs.BadRequest(errors.Errorf("authority.Revoke; certificate with serial "+
"number %s has already been revoked", rci.Serial), opts...)
return errs.BadRequest("authority.Revoke; certificate with serial "+
"number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...)
default:
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
}

View file

@ -553,7 +553,7 @@ retry:
// verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw)
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
return nil, errs.BadRequest(errors.New("client.Root; root certificate SHA256 fingerprint do not match"))
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
}
return &root, nil
}
@ -961,8 +961,8 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u,
errs.WithMessage("Failed to perform POST request to %s", u))
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed",
[]interface{}{u, errs.WithMessage("Failed to perform POST request to %s", u)}...)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
@ -974,8 +974,8 @@ retry:
}
var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u,
errs.WithMessage("Failed to parse response from /ssh/check-host endpoint"))
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response",
[]interface{}{u, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")})
}
return &check, nil
}

View file

@ -163,8 +163,8 @@ func TestClient_Version(t *testing.T) {
expectedErr error
}{
{"ok", ok, 200, false, nil},
{"500", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
{"404", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
{"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
{"404", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -214,7 +214,7 @@ func TestClient_Health(t *testing.T) {
expectedErr error
}{
{"ok", ok, 200, false, nil},
{"not ok", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
{"not ok", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -268,7 +268,7 @@ func TestClient_Root(t *testing.T) {
expectedErr error
}{
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil},
{"not found", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
{"not found", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -336,9 +336,9 @@ func TestClient_Sign(t *testing.T) {
expectedErr error
}{
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -409,8 +409,8 @@ func TestClient_Revoke(t *testing.T) {
expectedErr error
}{
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -483,9 +483,9 @@ func TestClient_Renew(t *testing.T) {
err error
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -533,7 +533,7 @@ func TestClient_Provisioners(t *testing.T) {
ok := &api.ProvisionersResponse{
Provisioners: provisioner.List{},
}
internalServerError := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
internalServerError := errs.InternalServer("Internal Server Error")
tests := []struct {
name string
@ -603,7 +603,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
err error
}{
{"ok", "kid", ok, 200, false, nil},
{"fail", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
{"fail", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -665,8 +665,8 @@ func TestClient_Roots(t *testing.T) {
err error
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"bad-request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -724,7 +724,7 @@ func TestClient_Federation(t *testing.T) {
err error
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -786,7 +786,7 @@ func TestClient_SSHRoots(t *testing.T) {
err error
}{
{"ok", ok, 200, false, nil},
{"not found", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
{"not found", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -869,7 +869,7 @@ func Test_parseEndpoint(t *testing.T) {
func TestClient_RootFingerprint(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"}
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
nok := errs.InternalServer("Internal Server Error")
httpsServer := httptest.NewTLSServer(nil)
defer httpsServer.Close()
@ -947,7 +947,7 @@ func TestClient_SSHBastion(t *testing.T) {
}{
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)

View file

@ -62,31 +62,6 @@ type Error struct {
Details map[string]interface{}
}
// New returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func New(status int, err error, opts ...Option) error {
var (
e *Error
ok bool
)
if e, ok = err.(*Error); !ok {
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
}
}
}
for _, o := range opts {
o(e)
}
return e
}
// ErrorResponse represents an error in JSON format.
type ErrorResponse struct {
Status int `json:"status"`
@ -119,10 +94,11 @@ func (e *Error) Message() string {
// Wrap returns an error annotating err with a stack trace at the point Wrap is
// called, and the supplied message. If err is nil, Wrap returns nil.
func Wrap(status int, e error, m string, opts ...Option) error {
func Wrap(status int, e error, m string, args ...interface{}) error {
if e == nil {
return nil
}
_, opts := splitOptionArgs(args)
if err, ok := e.(*Error); ok {
err.Err = errors.Wrap(err.Err, m)
e = err
@ -138,25 +114,12 @@ func Wrapf(status int, e error, format string, args ...interface{}) error {
if e == nil {
return nil
}
var opts []Option
for i, arg := range args {
// Once we find the first Option, assume that all further arguments are Options.
if _, ok := arg.(Option); ok {
for _, a := range args[i:] {
// Ignore any arguments after the first Option that are not Options.
if opt, ok := a.(Option); ok {
opts = append(opts, opt)
}
}
args = args[:i]
break
}
}
as, opts := splitOptionArgs(args)
if err, ok := e.(*Error); ok {
err.Err = errors.Wrapf(err.Err, format, args...)
e = err
} else {
e = errors.Wrapf(e, format, args...)
e = errors.Wrapf(e, format, as...)
}
return StatusCodeError(status, e, opts...)
}
@ -201,24 +164,24 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error {
switch code {
case http.StatusBadRequest:
return BadRequest(e, opts...)
return BadRequestErr(e, opts...)
case http.StatusUnauthorized:
return Unauthorized(e, opts...)
return UnauthorizedErr(e, opts...)
case http.StatusForbidden:
return Forbidden(e, opts...)
return ForbiddenErr(e, opts...)
case http.StatusInternalServerError:
return InternalServerError(e, opts...)
return InternalServerErr(e, opts...)
case http.StatusNotImplemented:
return NotImplemented(e, opts...)
return NotImplementedErr(e, opts...)
default:
return UnexpectedError(code, e, opts...)
return UnexpectedErr(code, e, opts...)
}
}
var (
seeLogs = "Please see the certificate authority logs for more info."
// BadRequestDefaultMsg 400 default msg
BadRequestDefaultMsg = "The request could not be completed due to being poorly formatted or missing critical data. " + seeLogs
BadRequestDefaultMsg = "The request could not be completed; malformed or missing data" + seeLogs
// UnauthorizedDefaultMsg 401 default msg
UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs
// ForbiddenDefaultMsg 403 default msg
@ -231,46 +194,142 @@ var (
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
)
// InternalServerError returns a 500 error with the given error.
func InternalServerError(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg))
return New(http.StatusInternalServerError, err, opts...)
// splitOptionArgs splits the variadic length args into string formatting args
// and Option(s) to apply to an Error.
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
indexOptionStart := -1
for i, a := range args {
if _, ok := a.(Option); ok {
indexOptionStart = i
break
}
}
// NotImplemented returns a 501 error with the given error.
func NotImplemented(err error, opts ...Option) error {
if indexOptionStart < 0 {
return args, []Option{}
}
opts := []Option{}
// Ignore any non-Option args that come after the first Option.
for _, o := range args[indexOptionStart:] {
if opt, ok := o.(Option); ok {
opts = append(opts, opt)
}
}
return args[:indexOptionStart], opts
}
// NewErr returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func NewErr(status int, err error, opts ...Option) error {
var (
e *Error
ok bool
)
if e, ok = err.(*Error); !ok {
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
}
}
}
for _, o := range opts {
o(e)
}
return e
}
// Errorf creates a new error using the given format and status code.
func Errorf(code int, format string, args ...interface{}) error {
as, opts := splitOptionArgs(args)
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
return New(http.StatusNotImplemented, err, opts...)
e := &Error{Status: code, Err: fmt.Errorf(format, as...)}
for _, o := range opts {
o(e)
}
return e
}
// BadRequest returns an 400 error with the given error.
func BadRequest(err error, opts ...Option) error {
// InternalServer creates a 500 error with the given format and arguments.
func InternalServer(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
return Errorf(http.StatusInternalServerError, format, args...)
}
// InternalServerErr returns a 500 error with the given error.
func InternalServerErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg))
return NewErr(http.StatusInternalServerError, err, opts...)
}
// NotImplemented creates a 501 error with the given format and arguments.
func NotImplemented(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(NotImplementedDefaultMsg))
return Errorf(http.StatusNotImplemented, format, args...)
}
// NotImplementedErr returns a 501 error with the given error.
func NotImplementedErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
return NewErr(http.StatusNotImplemented, err, opts...)
}
// BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(BadRequestDefaultMsg))
return Errorf(http.StatusBadRequest, format, args...)
}
// BadRequestErr returns an 400 error with the given error.
func BadRequestErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return New(http.StatusBadRequest, err, opts...)
return NewErr(http.StatusBadRequest, err, opts...)
}
// Unauthorized returns an 401 error with the given error.
func Unauthorized(err error, opts ...Option) error {
// Unauthorized creates a 401 error with the given format and arguments.
func Unauthorized(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(UnauthorizedDefaultMsg))
return Errorf(http.StatusUnauthorized, format, args...)
}
// UnauthorizedErr returns an 401 error with the given error.
func UnauthorizedErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg))
return New(http.StatusUnauthorized, err, opts...)
return NewErr(http.StatusUnauthorized, err, opts...)
}
// Forbidden returns an 403 error with the given error.
func Forbidden(err error, opts ...Option) error {
// Forbidden creates a 403 error with the given format and arguments.
func Forbidden(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(ForbiddenDefaultMsg))
return Errorf(http.StatusForbidden, format, args...)
}
// ForbiddenErr returns an 403 error with the given error.
func ForbiddenErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
return New(http.StatusForbidden, err, opts...)
return NewErr(http.StatusForbidden, err, opts...)
}
// NotFound returns an 404 error with the given error.
func NotFound(err error, opts ...Option) error {
// NotFound creates a 404 error with the given format and arguments.
func NotFound(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(NotFoundDefaultMsg))
return Errorf(http.StatusNotFound, format, args...)
}
// NotFoundErr returns an 404 error with the given error.
func NotFoundErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(NotFoundDefaultMsg))
return New(http.StatusNotFound, err, opts...)
return NewErr(http.StatusNotFound, err, opts...)
}
// UnexpectedError will be used when the certificate authority makes an outgoing
// UnexpectedErr will be used when the certificate authority makes an outgoing
// request and receives an unhandled status code.
func UnexpectedError(code int, err error, opts ...Option) error {
func UnexpectedErr(code int, err error, opts ...Option) error {
opts = append(opts, withDefaultMessage("The certificate authority received an "+
"unexpected HTTP status code - '%d'. "+seeLogs, code))
return New(code, err, opts...)
return NewErr(code, err, opts...)
}