forked from TrueCloudLab/certificates
Simplify statuscoder error generators.
This commit is contained in:
parent
dccbdf3a90
commit
1cb8bb3ae1
45 changed files with 483 additions and 441 deletions
|
@ -63,6 +63,7 @@ issues:
|
|||
- declaration of "err" shadows declaration at line
|
||||
- should have a package comment, unless it's in another file for this package
|
||||
- error strings should not be capitalized or end with punctuation or a newline
|
||||
- Wrapf call needs 1 arg but has 2 args
|
||||
# golangci.com configuration
|
||||
# https://github.com/golangci/golangci/wiki/Configuration
|
||||
service:
|
||||
|
|
12
api/api.go
12
api/api.go
|
@ -295,7 +295,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
|||
// Load root certificate with the
|
||||
cert, err := h.Authority.Root(sum)
|
||||
if err != nil {
|
||||
WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI)))
|
||||
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -314,13 +314,13 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
|||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := parseCursor(r)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &ProvisionersResponse{
|
||||
|
@ -334,7 +334,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
|||
kid := chi.URLParam(r, "kid")
|
||||
key, err := h.Authority.GetEncryptedKey(kid)
|
||||
if err != nil {
|
||||
WriteError(w, errs.NotFound(err))
|
||||
WriteError(w, errs.NotFoundErr(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &ProvisionerKeyResponse{key})
|
||||
|
@ -344,7 +344,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -362,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := h.Authority.GetFederation()
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -915,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
|
||||
{"renew error", cs, nil, nil, errs.Forbidden(fmt.Errorf("an error")), http.StatusForbidden},
|
||||
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
|
@ -1010,10 +1010,10 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedError400 := errs.BadRequest(errors.New("force"))
|
||||
expectedError400 := errs.BadRequest("force")
|
||||
expectedError400Bytes, err := json.Marshal(expectedError400)
|
||||
assert.FatalError(t, err)
|
||||
expectedError500 := errs.InternalServerError(errors.New("force"))
|
||||
expectedError500 := errs.InternalServer("force")
|
||||
expectedError500Bytes, err := json.Marshal(expectedError500)
|
||||
assert.FatalError(t, err)
|
||||
for _, tt := range tests {
|
||||
|
@ -1082,7 +1082,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
}
|
||||
|
||||
expected := []byte(`{"key":"` + privKey + `"}`)
|
||||
expectedError404 := errs.NotFound(errors.New("force"))
|
||||
expectedError404 := errs.NotFound("force")
|
||||
expectedError404Bytes, err := json.Marshal(expectedError404)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package api
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
|
@ -11,7 +10,7 @@ import (
|
|||
// new one.
|
||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
WriteError(w, errs.BadRequest(errors.New("missing peer certificate")))
|
||||
WriteError(w, errs.BadRequest("missing peer certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -22,7 +21,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 0 {
|
||||
if len(certChainPEM) > 1 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
@ -30,13 +29,13 @@ type RevokeRequest struct {
|
|||
// or an error if something is wrong.
|
||||
func (r *RevokeRequest) Validate() (err error) {
|
||||
if r.Serial == "" {
|
||||
return errs.BadRequest(errors.New("missing serial"))
|
||||
return errs.BadRequest("missing serial")
|
||||
}
|
||||
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
|
||||
return errs.BadRequest(errors.New("reasonCode out of bounds"))
|
||||
return errs.BadRequest("reasonCode out of bounds")
|
||||
}
|
||||
if !r.Passive {
|
||||
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
|
||||
return errs.NotImplemented("non-passive revocation not implemented")
|
||||
}
|
||||
|
||||
return
|
||||
|
@ -50,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
|
|||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -72,7 +71,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
if len(body.OTT) > 0 {
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
@ -81,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
// the client certificate Serial Number must match the serial number
|
||||
// being revoked.
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate")))
|
||||
WriteError(w, errs.BadRequest("missing ott or peer certificate"))
|
||||
return
|
||||
}
|
||||
opts.Crt = r.TLS.PeerCertificates[0]
|
||||
if opts.Crt.SerialNumber.String() != opts.Serial {
|
||||
WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body")))
|
||||
WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body"))
|
||||
return
|
||||
}
|
||||
// TODO: should probably be checking if the certificate was revoked here.
|
||||
|
@ -97,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -190,7 +190,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
return errs.InternalServerError(errors.New("force"))
|
||||
return errs.InternalServer("force")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
15
api/sign.go
15
api/sign.go
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
|
@ -22,13 +21,13 @@ type SignRequest struct {
|
|||
// or an error if something is wrong.
|
||||
func (s *SignRequest) Validate() error {
|
||||
if s.CsrPEM.CertificateRequest == nil {
|
||||
return errs.BadRequest(errors.New("missing csr"))
|
||||
return errs.BadRequest("missing csr")
|
||||
}
|
||||
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
|
||||
return errs.BadRequest(errors.Wrap(err, "invalid csr"))
|
||||
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
|
||||
}
|
||||
if s.OTT == "" {
|
||||
return errs.BadRequest(errors.New("missing ott"))
|
||||
return errs.BadRequest("missing ott")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -49,7 +48,7 @@ type SignResponse struct {
|
|||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -66,18 +65,18 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 0 {
|
||||
if len(certChainPEM) > 1 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
logCertificate(w, certChain[0])
|
||||
|
|
46
api/ssh.go
46
api/ssh.go
|
@ -249,19 +249,19 @@ type SSHBastionResponse struct {
|
|||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -269,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
if body.AddUserPublicKey != nil {
|
||||
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -285,13 +285,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -299,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
|
||||
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
addUserCertificate = &SSHCertificate{addUserCert}
|
||||
|
@ -320,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
certChain, err := h.Authority.Sign(cr, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
identityCertificate = certChainToPEM(certChain)
|
||||
|
@ -343,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHRoots()
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
WriteError(w, errs.NotFound(errors.New("no keys found")))
|
||||
WriteError(w, errs.NotFound("no keys found"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -368,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHFederation()
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
WriteError(w, errs.NotFound(errors.New("no keys found")))
|
||||
WriteError(w, errs.NotFound("no keys found"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -393,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -414,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
case provisioner.SSHHostCert:
|
||||
config.HostTemplates = ts
|
||||
default:
|
||||
WriteError(w, errs.InternalServerError(errors.New("it should hot get here")))
|
||||
WriteError(w, errs.InternalServer("it should hot get here"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -429,13 +429,13 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &SSHCheckPrincipalResponse{
|
||||
|
@ -452,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
hosts, err := h.Authority.GetSSHHosts(cert)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &SSHGetHostsResponse{
|
||||
|
@ -464,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -40,42 +40,42 @@ type SSHRekeyResponse struct {
|
|||
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
}
|
||||
|
||||
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -36,36 +36,36 @@ type SSHRenewResponse struct {
|
|||
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequest(err))
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
|
||||
_, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, errs.InternalServerError(err))
|
||||
WriteError(w, errs.InternalServerErr(err))
|
||||
}
|
||||
|
||||
newCert, err := h.Authority.RenewSSH(oldCert)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
@ -30,16 +29,16 @@ type SSHRevokeRequest struct {
|
|||
// or an error if something is wrong.
|
||||
func (r *SSHRevokeRequest) Validate() (err error) {
|
||||
if r.Serial == "" {
|
||||
return errs.BadRequest(errors.New("missing serial"))
|
||||
return errs.BadRequest("missing serial")
|
||||
}
|
||||
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
|
||||
return errs.BadRequest(errors.New("reasonCode out of bounds"))
|
||||
return errs.BadRequest("reasonCode out of bounds")
|
||||
}
|
||||
if !r.Passive {
|
||||
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
|
||||
return errs.NotImplemented("non-passive revocation not implemented")
|
||||
}
|
||||
if len(r.OTT) == 0 {
|
||||
return errs.BadRequest(errors.New("missing ott"))
|
||||
return errs.BadRequest("missing ott")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -50,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
|||
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -71,13 +70,13 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
|||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
WriteError(w, errs.Unauthorized(err))
|
||||
WriteError(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
WriteError(w, errs.Forbidden(err))
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
@ -69,7 +68,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
|||
// pointed by v.
|
||||
func ReadJSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return errs.BadRequest(errors.Wrap(err, "error decoding json"))
|
||||
return errs.Wrap(http.StatusBadRequest, err, "error decoding json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/jose"
|
||||
|
@ -58,15 +57,15 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
|||
// This check is meant as a stopgap solution to the current lack of a persistence layer.
|
||||
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
|
||||
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
|
||||
return nil, errs.Unauthorized(errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"))
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority")
|
||||
}
|
||||
}
|
||||
|
||||
// This method will also validate the audiences for JWK provisioners.
|
||||
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
|
||||
if !ok {
|
||||
return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: provisioner "+
|
||||
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")))
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
|
||||
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
|
||||
}
|
||||
|
||||
// Store the token to protect against reuse unless it's skipped.
|
||||
|
@ -78,7 +77,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
|||
"authority.authorizeToken: failed when attempting to store token")
|
||||
}
|
||||
if !ok {
|
||||
return nil, errs.Unauthorized(errors.Errorf("authority.authorizeToken: token already used"))
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: token already used")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -89,7 +88,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
|||
// Authorize grabs the method from the context and authorizes the request by
|
||||
// validating the one-time-token.
|
||||
func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
var opts = []errs.Option{errs.WithKeyVal("token", token)}
|
||||
var opts = []interface{}{errs.WithKeyVal("token", token)}
|
||||
|
||||
switch m := provisioner.MethodFromContext(ctx); m {
|
||||
case provisioner.SignMethod:
|
||||
|
@ -99,13 +98,13 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.
|
|||
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...)
|
||||
case provisioner.SSHSignMethod:
|
||||
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...)
|
||||
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
|
||||
}
|
||||
_, err := a.authorizeSSHSign(ctx, token)
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
|
||||
case provisioner.SSHRenewMethod:
|
||||
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...)
|
||||
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
|
||||
}
|
||||
_, err := a.authorizeSSHRenew(ctx, token)
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
|
||||
|
@ -113,12 +112,12 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.
|
|||
return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...)
|
||||
case provisioner.SSHRekeyMethod:
|
||||
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("authority.Authorize; ssh certificate flows are not enabled"), opts...)
|
||||
return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...)
|
||||
}
|
||||
_, signOpts, err := a.authorizeSSHRekey(ctx, token)
|
||||
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
|
||||
default:
|
||||
return nil, errs.InternalServerError(errors.Errorf("authority.Authorize; method %d is not supported", m), opts...)
|
||||
return nil, errs.InternalServer("authority.Authorize; method %d is not supported", append([]interface{}{m}, opts...)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -165,7 +164,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
|||
//
|
||||
// TODO(mariano): should we authorize by default?
|
||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||
var opts = []errs.Option{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
|
||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
|
||||
|
||||
// Check the passive revocation table.
|
||||
isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String())
|
||||
|
@ -173,12 +172,12 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
|||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
}
|
||||
if isRevoked {
|
||||
return errs.Unauthorized(errors.New("authority.authorizeRenew: certificate has been revoked"), opts...)
|
||||
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
|
||||
}
|
||||
|
||||
p, ok := a.provisioners.LoadByCertificate(cert)
|
||||
if !ok {
|
||||
return errs.Unauthorized(errors.New("authority.authorizeRenew: provisioner not found"), opts...)
|
||||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||
}
|
||||
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
|
|
|
@ -180,7 +180,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
}
|
||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
_, err = _a.authorizeToken(context.TODO(), raw)
|
||||
_, err = _a.authorizeToken(context.Background(), raw)
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: _a,
|
||||
|
@ -268,7 +268,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
tc := genTestCase(t)
|
||||
|
||||
p, err := tc.auth.authorizeToken(context.TODO(), tc.token)
|
||||
p, err := tc.auth.authorizeToken(context.Background(), tc.token)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
|
@ -355,7 +355,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
tc := genTestCase(t)
|
||||
|
||||
if err := tc.auth.authorizeRevoke(context.TODO(), tc.token); err != nil {
|
||||
if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
|
|
|
@ -80,7 +80,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized(errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()))
|
||||
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -306,7 +306,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized(errors.Errorf("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID()))
|
||||
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -353,7 +353,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token")
|
||||
}
|
||||
if len(jwt.Headers) == 0 {
|
||||
return nil, errs.InternalServerError(errors.New("aws.authorizeToken; error parsing token, header is missing"))
|
||||
return nil, errs.InternalServer("aws.authorizeToken; error parsing token, header is missing")
|
||||
}
|
||||
|
||||
var unsafeClaims awsPayload
|
||||
|
@ -378,13 +378,13 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
|
||||
switch {
|
||||
case doc.AccountID == "":
|
||||
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document accountId cannot be empty"))
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document accountId cannot be empty")
|
||||
case doc.InstanceID == "":
|
||||
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty"))
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document instanceId cannot be empty")
|
||||
case doc.PrivateIP == "":
|
||||
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty"))
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document privateIp cannot be empty")
|
||||
case doc.Region == "":
|
||||
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document region cannot be empty"))
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document region cannot be empty")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -399,7 +399,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(payload.Audience, p.audiences.Sign) {
|
||||
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)"))
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
// Validate subject, it has to be known if disableCustomSANs is enabled
|
||||
|
@ -407,7 +407,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
if payload.Subject != doc.InstanceID &&
|
||||
payload.Subject != doc.PrivateIP &&
|
||||
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
|
||||
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)"))
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -421,14 +421,14 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid"))
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid aws identity document - accountId is not valid")
|
||||
}
|
||||
}
|
||||
|
||||
// validate instance age
|
||||
if d := p.InstanceAge.Value(); d > 0 {
|
||||
if now.Sub(doc.PendingTime) > d {
|
||||
return nil, errs.Unauthorized(errors.New("aws.authorizeToken; aws identity document pendingTime is too old"))
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; aws identity document pendingTime is too old")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -439,7 +439,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized(errors.Errorf("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID()))
|
||||
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())
|
||||
}
|
||||
claims, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
|
@ -462,7 +462,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
|
||||
|
||||
|
@ -474,8 +474,8 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -704,7 +704,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.aws.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
|
|
|
@ -210,14 +210,14 @@ func (p *Azure) Init(config Config) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// authorizeToken returs the claims, name, group, error.
|
||||
// authorizeToken returns the claims, name, group, error.
|
||||
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
|
||||
}
|
||||
if len(jwt.Headers) == 0 {
|
||||
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token missing header"))
|
||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
|
||||
}
|
||||
|
||||
var found bool
|
||||
|
@ -230,7 +230,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; cannot validate azure token"))
|
||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
|
||||
}
|
||||
|
||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||
|
@ -243,12 +243,12 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
|
|||
|
||||
// Validate TenantID
|
||||
if claims.TenantID != p.TenantID {
|
||||
return nil, "", "", errs.Unauthorized(errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)"))
|
||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
|
||||
}
|
||||
|
||||
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
|
||||
if len(re) != 4 {
|
||||
return nil, "", "", errs.Unauthorized(errors.Errorf("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID))
|
||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
|
||||
}
|
||||
group, name := re[2], re[3]
|
||||
return &claims, name, group, nil
|
||||
|
@ -272,7 +272,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errs.Unauthorized(errors.New("azure.AuthorizeSign; azure token validation failed - invalid resource group"))
|
||||
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid resource group")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -302,7 +302,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized(errors.Errorf("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID()))
|
||||
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -310,7 +310,7 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized(errors.Errorf("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID()))
|
||||
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
|
||||
_, name, _, err := p.authorizeToken(token)
|
||||
|
@ -328,7 +328,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
Principals: []string{name},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
|
||||
|
||||
|
@ -340,9 +340,9 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -488,7 +488,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.azure.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
|
|
|
@ -243,7 +243,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized(errors.Errorf("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID()))
|
||||
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -264,7 +264,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token")
|
||||
}
|
||||
if len(jwt.Headers) == 0 {
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; error parsing gcp token - header is missing"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; error parsing gcp token - header is missing")
|
||||
}
|
||||
|
||||
var found bool
|
||||
|
@ -278,7 +278,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errs.Unauthorized(errors.Errorf("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid)
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -293,7 +293,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, p.audiences.Sign) {
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
// validate subject (service account)
|
||||
|
@ -306,7 +306,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid subject claim"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid subject claim")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -320,26 +320,26 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; invalid gcp token - invalid project id"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid project id")
|
||||
}
|
||||
}
|
||||
|
||||
// validate instance age
|
||||
if d := p.InstanceAge.Value(); d > 0 {
|
||||
if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d {
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old")
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case claims.Google.ComputeEngine.InstanceID == "":
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty")
|
||||
case claims.Google.ComputeEngine.InstanceName == "":
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty")
|
||||
case claims.Google.ComputeEngine.ProjectID == "":
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty")
|
||||
case claims.Google.ComputeEngine.Zone == "":
|
||||
return nil, errs.Unauthorized(errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty"))
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
|
@ -348,7 +348,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized(errors.Errorf("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID()))
|
||||
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())
|
||||
}
|
||||
claims, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
|
@ -371,7 +371,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertDefaultsModifier(defaults))
|
||||
|
||||
|
@ -383,8 +383,8 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -680,7 +680,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
|
|
|
@ -120,12 +120,12 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
|
|||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
return nil, errs.Unauthorized(errors.Errorf("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s",
|
||||
audiences, claims.Audience))
|
||||
return nil, errs.Unauthorized("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s",
|
||||
audiences, claims.Audience)
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errs.Unauthorized(errors.New("jwk.authorizeToken; jwk token subject cannot be empty"))
|
||||
return nil, errs.Unauthorized("jwk.authorizeToken; jwk token subject cannot be empty")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
|
@ -173,7 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized(errors.Errorf("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID()))
|
||||
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -181,20 +181,20 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized(errors.Errorf("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID()))
|
||||
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())
|
||||
}
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
|
||||
}
|
||||
if claims.Step == nil || claims.Step.SSH == nil {
|
||||
return nil, errs.Unauthorized(errors.New("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token"))
|
||||
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token")
|
||||
}
|
||||
|
||||
opts := claims.Step.SSH
|
||||
signOptions := []SignOption{
|
||||
// validates user's SSHOptions with the ones in the token
|
||||
sshCertificateOptionsValidator(*opts),
|
||||
sshCertOptionsValidator(*opts),
|
||||
}
|
||||
|
||||
t := now()
|
||||
|
@ -231,9 +231,9 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -222,7 +222,7 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRevoke(context.TODO(), tt.args.token); err != nil {
|
||||
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
|
@ -337,7 +337,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenew(context.TODO(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
|
|
|
@ -149,7 +149,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
|||
claims k8sSAPayload
|
||||
)
|
||||
if p.pubKeys == nil {
|
||||
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented"))
|
||||
return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented")
|
||||
/* NOTE: We plan to support the TokenReview API in a future release.
|
||||
Below is some code that should be useful when we prioritize
|
||||
this integration.
|
||||
|
@ -177,7 +177,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
|||
}
|
||||
}
|
||||
if !valid {
|
||||
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims"))
|
||||
return nil, errs.Unauthorized("k8ssa.authorizeToken; error validating k8sSA token and extracting claims")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
|
@ -189,7 +189,7 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
|||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errs.Unauthorized(errors.New("k8ssa.authorizeToken; k8sSA token subject cannot be empty"))
|
||||
return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA token subject cannot be empty")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
|
@ -221,7 +221,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()))
|
||||
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -229,7 +229,7 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
|
|||
// AuthorizeSSHSign validates an request for an SSH certificate.
|
||||
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized(errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()))
|
||||
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())
|
||||
}
|
||||
if _, err := p.authorizeToken(token, p.audiences.SSHSign); err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
|
||||
|
@ -246,9 +246,9 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -363,10 +363,10 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
|||
case sshCertDefaultsModifier:
|
||||
assert.Equals(t, v.CertType, SSHUserCert)
|
||||
case *sshDefaultExtensionModifier:
|
||||
case *sshCertificateValidityValidator:
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
case *sshDefaultPublicKeyValidator:
|
||||
case *sshCertificateDefaultValidator:
|
||||
case *sshCertDefaultValidator:
|
||||
case *sshDefaultDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
default:
|
||||
|
|
|
@ -14,8 +14,8 @@ func Test_noop(t *testing.T) {
|
|||
assert.Equals(t, "noop", p.GetName())
|
||||
assert.Equals(t, noopType, p.GetType())
|
||||
assert.Equals(t, nil, p.Init(Config{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRenew(context.TODO(), &x509.Certificate{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRevoke(context.TODO(), "foo"))
|
||||
assert.Equals(t, nil, p.AuthorizeRenew(context.Background(), &x509.Certificate{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRevoke(context.Background(), "foo"))
|
||||
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
assert.Equals(t, "", kid)
|
||||
|
|
|
@ -195,12 +195,12 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
|||
|
||||
// Validate azp if present
|
||||
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
|
||||
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: invalid azp"))
|
||||
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: invalid azp")
|
||||
}
|
||||
|
||||
// Enforce an email claim
|
||||
if p.Email == "" {
|
||||
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email not found"))
|
||||
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email not found")
|
||||
}
|
||||
|
||||
// Validate domains (case-insensitive)
|
||||
|
@ -214,7 +214,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return errs.Unauthorized(errors.New("validatePayload: failed to validate oidc token payload: email is not allowed"))
|
||||
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email is not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -230,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return errs.Unauthorized(errors.New("validatePayload: oidc token payload validation failed: invalid group"))
|
||||
return errs.Unauthorized("validatePayload: oidc token payload validation failed: invalid group")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -263,7 +263,7 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errs.Unauthorized(errors.New("oidc.AuthorizeToken; cannot validate oidc token"))
|
||||
return nil, errs.Unauthorized("oidc.AuthorizeToken; cannot validate oidc token")
|
||||
}
|
||||
|
||||
if err := o.ValidatePayload(claims); err != nil {
|
||||
|
@ -286,7 +286,7 @@ func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
|
|||
if o.IsAdmin(claims.Email) {
|
||||
return nil
|
||||
}
|
||||
return errs.Unauthorized(errors.New("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token"))
|
||||
return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
|
@ -318,7 +318,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// certificate was configured to allow renewals.
|
||||
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if o.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized(errors.Errorf("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID()))
|
||||
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -326,7 +326,7 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !o.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized(errors.Errorf("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID()))
|
||||
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())
|
||||
}
|
||||
claims, err := o.authorizeToken(token)
|
||||
if err != nil {
|
||||
|
@ -352,7 +352,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
|||
// Non-admin users can only use principals returned by the identityFunc, and
|
||||
// can only sign user certificates.
|
||||
if !o.IsAdmin(claims.Email) {
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
signOptions = append(signOptions, sshCertOptionsValidator(defaults))
|
||||
}
|
||||
|
||||
// Default to a user certificate with usernames as principals if those options
|
||||
|
@ -367,9 +367,9 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{o.claimer},
|
||||
&sshCertValidityValidator{o.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
|
@ -382,7 +382,7 @@ func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
|||
|
||||
// Only admins can revoke certificates.
|
||||
if !o.IsAdmin(claims.Email) {
|
||||
return errs.Unauthorized(errors.New("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token"))
|
||||
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -284,43 +284,43 @@ type base struct{}
|
|||
// AuthorizeSign returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for signing x509 Certificates.
|
||||
func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSign not implemented"))
|
||||
return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for revoking x509 Certificates.
|
||||
func (b *base) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
return errs.Unauthorized(errors.New("provisioner.AuthorizeRevoke not implemented"))
|
||||
return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeRenew returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for renewing x509 Certificates.
|
||||
func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
return errs.Unauthorized(errors.New("provisioner.AuthorizeRenew not implemented"))
|
||||
return errs.Unauthorized("provisioner.AuthorizeRenew not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeSSHSign returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for signing SSH Certificates.
|
||||
func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHSign not implemented"))
|
||||
return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for revoking SSH Certificates.
|
||||
func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
return errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRevoke not implemented"))
|
||||
return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeSSHRenew returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for renewing SSH Certificates.
|
||||
func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
||||
return nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRenew not implemented"))
|
||||
return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented")
|
||||
}
|
||||
|
||||
// AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite
|
||||
// this method if they will support authorizing tokens for rekeying SSH Certificates.
|
||||
func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
|
||||
return nil, nil, errs.Unauthorized(errors.New("provisioner.AuthorizeSSHRekey not implemented"))
|
||||
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
|
||||
}
|
||||
|
||||
// Identity is the type representing an externally supplied identity that is used
|
||||
|
|
|
@ -19,29 +19,29 @@ const (
|
|||
SSHHostCert = "host"
|
||||
)
|
||||
|
||||
// SSHCertificateModifier is the interface used to change properties in an SSH
|
||||
// SSHCertModifier is the interface used to change properties in an SSH
|
||||
// certificate.
|
||||
type SSHCertificateModifier interface {
|
||||
type SSHCertModifier interface {
|
||||
SignOption
|
||||
Modify(cert *ssh.Certificate) error
|
||||
}
|
||||
|
||||
// SSHCertificateOptionModifier is the interface used to add custom options used
|
||||
// SSHCertOptionModifier is the interface used to add custom options used
|
||||
// to modify the SSH certificate.
|
||||
type SSHCertificateOptionModifier interface {
|
||||
type SSHCertOptionModifier interface {
|
||||
SignOption
|
||||
Option(o SSHOptions) SSHCertificateModifier
|
||||
Option(o SSHOptions) SSHCertModifier
|
||||
}
|
||||
|
||||
// SSHCertificateValidator is the interface used to validate an SSH certificate.
|
||||
type SSHCertificateValidator interface {
|
||||
// SSHCertValidator is the interface used to validate an SSH certificate.
|
||||
type SSHCertValidator interface {
|
||||
SignOption
|
||||
Valid(cert *ssh.Certificate) error
|
||||
}
|
||||
|
||||
// SSHCertificateOptionsValidator is the interface used to validate the custom
|
||||
// SSHCertOptionsValidator is the interface used to validate the custom
|
||||
// options used to modify the SSH certificate.
|
||||
type SSHCertificateOptionsValidator interface {
|
||||
type SSHCertOptionsValidator interface {
|
||||
SignOption
|
||||
Valid(got SSHOptions) error
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ func (o SSHOptions) Type() uint32 {
|
|||
return sshCertTypeUInt32(o.CertType)
|
||||
}
|
||||
|
||||
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
|
||||
// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
|
||||
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
|
||||
switch o.CertType {
|
||||
case "": // ignore
|
||||
|
@ -116,7 +116,7 @@ func (o SSHOptions) match(got SSHOptions) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the
|
||||
// sshCertPrincipalsModifier is an SSHCertModifier that sets the
|
||||
// principals to the SSH certificate.
|
||||
type sshCertPrincipalsModifier []string
|
||||
|
||||
|
@ -126,7 +126,7 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertKeyIDModifier is an SSHCertificateModifier that sets the given
|
||||
// sshCertKeyIDModifier is an SSHCertModifier that sets the given
|
||||
// Key ID in the SSH certificate.
|
||||
type sshCertKeyIDModifier string
|
||||
|
||||
|
@ -135,7 +135,7 @@ func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertTypeModifier is an SSHCertificateModifier that sets the
|
||||
// sshCertTypeModifier is an SSHCertModifier that sets the
|
||||
// certificate type.
|
||||
type sshCertTypeModifier string
|
||||
|
||||
|
@ -145,7 +145,7 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertValidAfterModifier is an SSHCertificateModifier that sets the
|
||||
// sshCertValidAfterModifier is an SSHCertModifier that sets the
|
||||
// ValidAfter in the SSH certificate.
|
||||
type sshCertValidAfterModifier uint64
|
||||
|
||||
|
@ -154,7 +154,7 @@ func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertValidBeforeModifier is an SSHCertificateModifier that sets the
|
||||
// sshCertValidBeforeModifier is an SSHCertModifier that sets the
|
||||
// ValidBefore in the SSH certificate.
|
||||
type sshCertValidBeforeModifier uint64
|
||||
|
||||
|
@ -163,11 +163,11 @@ func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshCertDefaultsModifier implements a SSHCertificateModifier that
|
||||
// sshCertDefaultsModifier implements a SSHCertModifier that
|
||||
// modifies the certificate with the given options if they are not set.
|
||||
type sshCertDefaultsModifier SSHOptions
|
||||
|
||||
// Modify implements the SSHCertificateModifier interface.
|
||||
// Modify implements the SSHCertModifier interface.
|
||||
func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
|
||||
if cert.CertType == 0 {
|
||||
cert.CertType = sshCertTypeUInt32(m.CertType)
|
||||
|
@ -184,7 +184,7 @@ func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
|
||||
// sshDefaultExtensionModifier implements an SSHCertModifier that sets
|
||||
// the default extensions in an SSH certificate.
|
||||
type sshDefaultExtensionModifier struct{}
|
||||
|
||||
|
@ -208,14 +208,14 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
|
|||
}
|
||||
}
|
||||
|
||||
// sshDefaultDuration is an SSHCertificateModifier that sets the certificate
|
||||
// sshDefaultDuration is an SSHCertModifier that sets the certificate
|
||||
// ValidAfter and ValidBefore if they have not been set. It will fail if a
|
||||
// CertType has not been set or is not valid.
|
||||
type sshDefaultDuration struct {
|
||||
*Claimer
|
||||
}
|
||||
|
||||
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertificateModifier {
|
||||
func (m *sshDefaultDuration) Option(o SSHOptions) SSHCertModifier {
|
||||
return sshModifierFunc(func(cert *ssh.Certificate) error {
|
||||
d, err := m.DefaultSSHCertDuration(cert.CertType)
|
||||
if err != nil {
|
||||
|
@ -248,7 +248,7 @@ type sshLimitDuration struct {
|
|||
NotAfter time.Time
|
||||
}
|
||||
|
||||
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
|
||||
func (m *sshLimitDuration) Option(o SSHOptions) SSHCertModifier {
|
||||
if m.NotAfter.IsZero() {
|
||||
defaultDuration := &sshDefaultDuration{m.Claimer}
|
||||
return defaultDuration.Option(o)
|
||||
|
@ -295,22 +295,22 @@ func (m *sshLimitDuration) Option(o SSHOptions) SSHCertificateModifier {
|
|||
})
|
||||
}
|
||||
|
||||
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
|
||||
// sshCertOptionsValidator validates the user SSHOptions with the ones
|
||||
// usually present in the token.
|
||||
type sshCertificateOptionsValidator SSHOptions
|
||||
type sshCertOptionsValidator SSHOptions
|
||||
|
||||
// Valid implements SSHCertificateOptionsValidator and returns nil if both
|
||||
// Valid implements SSHCertOptionsValidator and returns nil if both
|
||||
// SSHOptions match.
|
||||
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error {
|
||||
func (v sshCertOptionsValidator) Valid(got SSHOptions) error {
|
||||
want := SSHOptions(v)
|
||||
return want.match(got)
|
||||
}
|
||||
|
||||
type sshCertificateValidityValidator struct {
|
||||
type sshCertValidityValidator struct {
|
||||
*Claimer
|
||||
}
|
||||
|
||||
func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
|
||||
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate) error {
|
||||
switch {
|
||||
case cert.ValidAfter == 0:
|
||||
return errors.New("ssh certificate validAfter cannot be 0")
|
||||
|
@ -355,12 +355,12 @@ func (v *sshCertificateValidityValidator) Valid(cert *ssh.Certificate) error {
|
|||
}
|
||||
}
|
||||
|
||||
// sshCertificateDefaultValidator implements a simple validator for all the
|
||||
// sshCertDefaultValidator implements a simple validator for all the
|
||||
// fields in the SSH certificate.
|
||||
type sshCertificateDefaultValidator struct{}
|
||||
type sshCertDefaultValidator struct{}
|
||||
|
||||
// Valid returns an error if the given certificate does not contain the necessary fields.
|
||||
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
|
||||
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate) error {
|
||||
switch {
|
||||
case len(cert.Nonce) == 0:
|
||||
return errors.New("ssh certificate nonce cannot be empty")
|
||||
|
|
|
@ -489,12 +489,12 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
|
||||
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
|
||||
pub, _, err := keys.GenerateDefaultKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
sshPub, err := ssh.NewPublicKey(pub)
|
||||
assert.FatalError(t, err)
|
||||
v := sshCertificateDefaultValidator{}
|
||||
v := sshCertDefaultValidator{}
|
||||
tests := []struct {
|
||||
name string
|
||||
cert *ssh.Certificate
|
||||
|
@ -670,10 +670,10 @@ func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_sshCertificateValidityValidator(t *testing.T) {
|
||||
func Test_sshCertValidityValidator(t *testing.T) {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
v := sshCertificateValidityValidator{p.claimer}
|
||||
v := sshCertValidityValidator{p.claimer}
|
||||
n := now()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -992,7 +992,7 @@ func Test_sshLimitDuration_Option(t *testing.T) {
|
|||
name string
|
||||
fields fields
|
||||
args args
|
||||
want SSHCertificateModifier
|
||||
want SSHCertModifier
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
|
|
|
@ -45,22 +45,22 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var mods []SSHCertificateModifier
|
||||
var validators []SSHCertificateValidator
|
||||
var mods []SSHCertModifier
|
||||
var validators []SSHCertValidator
|
||||
|
||||
for _, op := range signOpts {
|
||||
switch o := op.(type) {
|
||||
// modify the ssh.Certificate
|
||||
case SSHCertificateModifier:
|
||||
case SSHCertModifier:
|
||||
mods = append(mods, o)
|
||||
// modify the ssh.Certificate given the SSHOptions
|
||||
case SSHCertificateOptionModifier:
|
||||
case SSHCertOptionModifier:
|
||||
mods = append(mods, o.Option(opts))
|
||||
// validate the ssh.Certificate
|
||||
case SSHCertificateValidator:
|
||||
case SSHCertValidator:
|
||||
validators = append(validators, o)
|
||||
// validate the given SSHOptions
|
||||
case SSHCertificateOptionsValidator:
|
||||
case SSHCertOptionsValidator:
|
||||
if err := o.Valid(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -112,20 +112,20 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"sshpop.authorizeToken; error checking checking sshpop cert revocation")
|
||||
} else if isRevoked {
|
||||
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate is revoked"))
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
|
||||
}
|
||||
|
||||
// Check validity period of the certificate.
|
||||
n := time.Now()
|
||||
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
|
||||
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future"))
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
|
||||
}
|
||||
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
|
||||
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past"))
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
|
||||
}
|
||||
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
|
||||
if !ok {
|
||||
return nil, errs.InternalServerError(errors.New("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey"))
|
||||
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
|
||||
}
|
||||
pubKey := sshCryptoPubKey.CryptoPublicKey()
|
||||
|
||||
|
@ -146,7 +146,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate"))
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")
|
||||
}
|
||||
|
||||
// Using the ssh certificates key to validate the claims accomplishes two
|
||||
|
@ -170,12 +170,12 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
return nil, errs.Unauthorized(errors.Errorf("sshpop.authorizeToken; sshpop token has invalid audience "+
|
||||
"claim (aud): expected %s, but got %s", audiences, claims.Audience))
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token has invalid audience "+
|
||||
"claim (aud): expected %s, but got %s", audiences, claims.Audience)
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errs.Unauthorized(errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty"))
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token subject cannot be empty")
|
||||
}
|
||||
|
||||
claims.sshCert = sshCert
|
||||
|
@ -190,8 +190,8 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
|||
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
|
||||
}
|
||||
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
|
||||
return errs.BadRequest(errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
|
||||
"must be equivalent to sshpop certificate serial number"))
|
||||
return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
|
||||
"must be equivalent to sshpop certificate serial number")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -204,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
|
|||
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
|
||||
}
|
||||
if claims.sshCert.CertType != ssh.HostCert {
|
||||
return nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"))
|
||||
return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")
|
||||
}
|
||||
|
||||
return claims.sshCert, nil
|
||||
|
@ -219,15 +219,15 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
|
|||
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
|
||||
}
|
||||
if claims.sshCert.CertType != ssh.HostCert {
|
||||
return nil, nil, errs.BadRequest(errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"))
|
||||
return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")
|
||||
}
|
||||
return claims.sshCert, []SignOption{
|
||||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
|
|
@ -564,8 +564,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
|||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case *sshDefaultPublicKeyValidator:
|
||||
case *sshCertificateDefaultValidator:
|
||||
case *sshCertificateValidityValidator:
|
||||
case *sshCertDefaultValidator:
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
|
|
|
@ -136,7 +136,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
|
|||
leaf := verifiedChains[0][0]
|
||||
|
||||
if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 {
|
||||
return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature"))
|
||||
return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
|
||||
}
|
||||
|
||||
// Using the leaf certificates key to validate the claims accomplishes two
|
||||
|
@ -160,12 +160,12 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
|
|||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
return nil, errs.Unauthorized(errors.Errorf("x5c.authorizeToken; x5c token has invalid audience "+
|
||||
"claim (aud); expected %s, but got %s", audiences, claims.Audience))
|
||||
return nil, errs.Unauthorized("x5c.authorizeToken; x5c token has invalid audience "+
|
||||
"claim (aud); expected %s, but got %s", audiences, claims.Audience)
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errs.Unauthorized(errors.New("x5c.authorizeToken; x5c token subject cannot be empty"))
|
||||
return nil, errs.Unauthorized("x5c.authorizeToken; x5c token subject cannot be empty")
|
||||
}
|
||||
|
||||
// Save the verified chains on the x5c payload object.
|
||||
|
@ -213,7 +213,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized(errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()))
|
||||
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -221,7 +221,7 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized(errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()))
|
||||
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID())
|
||||
}
|
||||
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
|
@ -230,13 +230,13 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
}
|
||||
|
||||
if claims.Step == nil || claims.Step.SSH == nil {
|
||||
return nil, errs.Unauthorized(errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"))
|
||||
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")
|
||||
}
|
||||
|
||||
opts := claims.Step.SSH
|
||||
signOptions := []SignOption{
|
||||
// validates user's SSHOptions with the ones in the token
|
||||
sshCertificateOptionsValidator(*opts),
|
||||
sshCertOptionsValidator(*opts),
|
||||
}
|
||||
|
||||
// Add modifiers from custom claims
|
||||
|
@ -272,8 +272,8 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate public key.
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertificateValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
&sshCertDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -548,7 +548,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeRevoke(context.TODO(), tc.token); err != nil {
|
||||
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
|
@ -594,7 +594,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeRenew(context.TODO(), nil); err != nil {
|
||||
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
|
@ -754,7 +754,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if opts, err := tc.p.AuthorizeSSHSign(context.TODO(), tc.token); err != nil {
|
||||
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
|
@ -768,7 +768,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
|||
nw := now()
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case sshCertificateOptionsValidator:
|
||||
case sshCertOptionsValidator:
|
||||
tc.claims.Step.SSH.ValidAfter.t = time.Time{}
|
||||
tc.claims.Step.SSH.ValidBefore.t = time.Time{}
|
||||
assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH)
|
||||
|
@ -787,10 +787,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
|||
case *sshLimitDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
|
||||
case *sshCertificateValidityValidator:
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
case *sshDefaultExtensionModifier, *sshDefaultPublicKeyValidator,
|
||||
*sshCertificateDefaultValidator:
|
||||
*sshCertDefaultValidator:
|
||||
case sshCertKeyIDValidator:
|
||||
assert.Equals(t, string(v), "foo")
|
||||
default:
|
||||
|
|
|
@ -2,18 +2,16 @@ package authority
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||
key, ok := a.provisioners.LoadEncryptedKey(kid)
|
||||
if !ok {
|
||||
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
return "", errs.NotFound("encrypted key with kid %s was not found", kid)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
@ -30,8 +28,7 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
|
|||
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
|
||||
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("provisioner not found"),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
return nil, errs.NotFound("provisioner not found")
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
@ -40,8 +37,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
|
|||
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
p, ok := a.provisioners.Load(id)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("provisioner not found"),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
return nil, errs.NotFound("provisioner not found")
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
|
|
@ -7,13 +7,15 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
func TestGetEncryptedKey(t *testing.T) {
|
||||
type ek struct {
|
||||
a *Authority
|
||||
kid string
|
||||
err *apiError
|
||||
err error
|
||||
code int
|
||||
}
|
||||
tests := map[string]func(t *testing.T) *ek{
|
||||
"ok": func(t *testing.T) *ek {
|
||||
|
@ -34,8 +36,8 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
return &ek{
|
||||
a: a,
|
||||
kid: "foo",
|
||||
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
|
||||
http.StatusNotFound, apiCtx{}},
|
||||
err: errors.New("encrypted key with kid foo was not found"),
|
||||
code: http.StatusNotFound,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -47,14 +49,10 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
ek, err := tc.a.GetEncryptedKey(tc.kid)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch v := err.(type) {
|
||||
case *apiError:
|
||||
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.code, tc.err.code)
|
||||
assert.Equals(t, v.context, tc.err.context)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
|
|
|
@ -2,23 +2,20 @@ package authority
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// Root returns the certificate corresponding to the given SHA sum argument.
|
||||
func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
||||
val, ok := a.certificates.Load(sum)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
return nil, errs.NotFound("certificate with fingerprint %s was not found", sum)
|
||||
}
|
||||
|
||||
crt, ok := val.(*x509.Certificate)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
return nil, errs.InternalServer("stored value is not a *x509.Certificate")
|
||||
}
|
||||
return crt, nil
|
||||
}
|
||||
|
@ -52,8 +49,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
|
|||
crt, ok := v.(*x509.Certificate)
|
||||
if !ok {
|
||||
federation = nil
|
||||
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
err = errs.InternalServer("stored value is not a *x509.Certificate")
|
||||
return false
|
||||
}
|
||||
federation = append(federation, crt)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
|
@ -17,11 +18,12 @@ func TestRoot(t *testing.T) {
|
|||
|
||||
tests := map[string]struct {
|
||||
sum string
|
||||
err *apiError
|
||||
err error
|
||||
code int
|
||||
}{
|
||||
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
|
||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
|
||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
||||
"not-found": {"foo", errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound},
|
||||
"invalid-stored-certificate": {"invaliddata", errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError},
|
||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil, http.StatusOK},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
|
@ -29,14 +31,10 @@ func TestRoot(t *testing.T) {
|
|||
crt, err := a.Root(tc.sum)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch v := err.(type) {
|
||||
case *apiError:
|
||||
assert.HasPrefix(t, v.err.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.code, tc.err.code)
|
||||
assert.Equals(t, v.context, tc.err.context)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
|
|
|
@ -122,7 +122,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
|
|||
// GetSSHConfig returns rendered templates for clients (user) or servers (host).
|
||||
func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
|
||||
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
|
||||
return nil, errs.NotFound(errors.New("getSSHConfig: ssh is not configured"))
|
||||
return nil, errs.NotFound("getSSHConfig: ssh is not configured")
|
||||
}
|
||||
|
||||
var ts []templates.Template
|
||||
|
@ -136,7 +136,7 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
|
|||
ts = a.config.Templates.SSH.Host
|
||||
}
|
||||
default:
|
||||
return nil, errs.BadRequest(errors.Errorf("getSSHConfig: type %s is not valid", typ))
|
||||
return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ)
|
||||
}
|
||||
|
||||
// Merge user and default data
|
||||
|
@ -177,13 +177,13 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
|
|||
}
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errs.NotFound(errors.New("authority.GetSSHBastion; ssh is not configured"))
|
||||
return nil, errs.NotFound("authority.GetSSHBastion; ssh is not configured")
|
||||
}
|
||||
|
||||
// SignSSH creates a signed SSH certificate with the given public key and options.
|
||||
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
var mods []provisioner.SSHCertificateModifier
|
||||
var validators []provisioner.SSHCertificateValidator
|
||||
var mods []provisioner.SSHCertModifier
|
||||
var validators []provisioner.SSHCertValidator
|
||||
|
||||
// Set backdate with the configured value
|
||||
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
|
||||
|
@ -191,27 +191,27 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
for _, op := range signOpts {
|
||||
switch o := op.(type) {
|
||||
// modify the ssh.Certificate
|
||||
case provisioner.SSHCertificateModifier:
|
||||
case provisioner.SSHCertModifier:
|
||||
mods = append(mods, o)
|
||||
// modify the ssh.Certificate given the SSHOptions
|
||||
case provisioner.SSHCertificateOptionModifier:
|
||||
case provisioner.SSHCertOptionModifier:
|
||||
mods = append(mods, o.Option(opts))
|
||||
// validate the ssh.Certificate
|
||||
case provisioner.SSHCertificateValidator:
|
||||
case provisioner.SSHCertValidator:
|
||||
validators = append(validators, o)
|
||||
// validate the given SSHOptions
|
||||
case provisioner.SSHCertificateOptionsValidator:
|
||||
case provisioner.SSHCertOptionsValidator:
|
||||
if err := o.Valid(opts); err != nil {
|
||||
return nil, errs.Forbidden(err)
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
||||
}
|
||||
default:
|
||||
return nil, errs.InternalServerError(errors.Errorf("signSSH: invalid extra option type %T", o))
|
||||
return nil, errs.InternalServer("signSSH: invalid extra option type %T", o)
|
||||
}
|
||||
}
|
||||
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, errs.InternalServerError(err)
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH")
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
|
@ -228,13 +228,13 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
|
||||
// Use opts to modify the certificate
|
||||
if err := opts.Modify(cert); err != nil {
|
||||
return nil, errs.Forbidden(err)
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
||||
}
|
||||
|
||||
// Use provisioner modifiers
|
||||
for _, m := range mods {
|
||||
if err := m.Modify(cert); err != nil {
|
||||
return nil, errs.Forbidden(err)
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -243,16 +243,16 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("signSSH: user certificate signing is not enabled"))
|
||||
return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAUserCertSignKey
|
||||
case ssh.HostCert:
|
||||
if a.sshCAHostCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("signSSH: host certificate signing is not enabled"))
|
||||
return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAHostCertSignKey
|
||||
default:
|
||||
return nil, errs.InternalServerError(errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType))
|
||||
return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType)
|
||||
}
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
|
@ -270,7 +270,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
// User provisioners validators
|
||||
for _, v := range validators {
|
||||
if err := v.Valid(cert); err != nil {
|
||||
return nil, errs.Forbidden(err)
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -285,7 +285,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
|||
func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, errs.InternalServerError(err)
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
|
@ -294,7 +294,7 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
|
|||
}
|
||||
|
||||
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
|
||||
return nil, errs.BadRequest(errors.New("rewnewSSH: cannot renew certificate without validity period"))
|
||||
return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period")
|
||||
}
|
||||
|
||||
backdate := a.config.AuthorityConfig.Backdate.Duration
|
||||
|
@ -321,16 +321,16 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
|
|||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("renewSSH: user certificate signing is not enabled"))
|
||||
return nil, errs.NotImplemented("renewSSH: user certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAUserCertSignKey
|
||||
case ssh.HostCert:
|
||||
if a.sshCAHostCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("renewSSH: host certificate signing is not enabled"))
|
||||
return nil, errs.NotImplemented("renewSSH: host certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAHostCertSignKey
|
||||
default:
|
||||
return nil, errs.InternalServerError(errors.Errorf("renewSSH: unexpected ssh certificate type: %d", cert.CertType))
|
||||
return nil, errs.InternalServer("renewSSH: unexpected ssh certificate type: %d", cert.CertType)
|
||||
}
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
|
@ -354,21 +354,21 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
|
|||
|
||||
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
|
||||
func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
var validators []provisioner.SSHCertificateValidator
|
||||
var validators []provisioner.SSHCertValidator
|
||||
|
||||
for _, op := range signOpts {
|
||||
switch o := op.(type) {
|
||||
// validate the ssh.Certificate
|
||||
case provisioner.SSHCertificateValidator:
|
||||
case provisioner.SSHCertValidator:
|
||||
validators = append(validators, o)
|
||||
default:
|
||||
return nil, errs.InternalServerError(errors.Errorf("rekeySSH; invalid extra option type %T", o))
|
||||
return nil, errs.InternalServer("rekeySSH; invalid extra option type %T", o)
|
||||
}
|
||||
}
|
||||
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, errs.InternalServerError(err)
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH")
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
|
@ -377,7 +377,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
|
|||
}
|
||||
|
||||
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
|
||||
return nil, errs.BadRequest(errors.New("rekeySSH; cannot rekey certificate without validity period"))
|
||||
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
|
||||
}
|
||||
|
||||
backdate := a.config.AuthorityConfig.Backdate.Duration
|
||||
|
@ -404,16 +404,16 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
|
|||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("rekeySSH; user certificate signing is not enabled"))
|
||||
return nil, errs.NotImplemented("rekeySSH; user certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAUserCertSignKey
|
||||
case ssh.HostCert:
|
||||
if a.sshCAHostCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("rekeySSH; host certificate signing is not enabled"))
|
||||
return nil, errs.NotImplemented("rekeySSH; host certificate signing is not enabled")
|
||||
}
|
||||
signer = a.sshCAHostCertSignKey
|
||||
default:
|
||||
return nil, errs.BadRequest(errors.Errorf("rekeySSH; unexpected ssh certificate type: %d", cert.CertType))
|
||||
return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)
|
||||
}
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
|
@ -431,7 +431,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
|
|||
// Apply validators from provisioner..
|
||||
for _, v := range validators {
|
||||
if err := v.Valid(cert); err != nil {
|
||||
return nil, errs.Forbidden(err)
|
||||
return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -445,18 +445,18 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
|
|||
// SignSSHAddUser signs a certificate that provisions a new user in a server.
|
||||
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, errs.NotImplemented(errors.New("signSSHAddUser: user certificate signing is not enabled"))
|
||||
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
|
||||
}
|
||||
if subject.CertType != ssh.UserCert {
|
||||
return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate is not a user certificate"))
|
||||
return nil, errs.Forbidden("signSSHAddUser: certificate is not a user certificate")
|
||||
}
|
||||
if len(subject.ValidPrincipals) != 1 {
|
||||
return nil, errs.Forbidden(errors.New("signSSHAddUser: certificate does not have only one principal"))
|
||||
return nil, errs.Forbidden("signSSHAddUser: certificate does not have only one principal")
|
||||
}
|
||||
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, errs.InternalServerError(err)
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser")
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
|
|
|
@ -80,7 +80,7 @@ func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
|
|||
|
||||
type sshTestOptionsModifier string
|
||||
|
||||
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier {
|
||||
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertModifier {
|
||||
return sshTestCertModifier(string(m))
|
||||
}
|
||||
|
||||
|
@ -492,12 +492,12 @@ func TestAuthority_CheckSSHHost(t *testing.T) {
|
|||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"true", fields{true, nil}, args{context.TODO(), "foo.internal.com", ""}, true, false},
|
||||
{"false", fields{false, nil}, args{context.TODO(), "foo.internal.com", ""}, false, false},
|
||||
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true},
|
||||
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true},
|
||||
{"internal", fields{false, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true},
|
||||
{"internal", fields{true, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true},
|
||||
{"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false},
|
||||
{"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false},
|
||||
{"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
|
||||
{"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true},
|
||||
{"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
|
||||
{"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -61,7 +61,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
|||
// Sign creates a signed certificate from a certificate signing request.
|
||||
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
var (
|
||||
opts = []errs.Option{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
|
||||
opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
|
||||
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
|
||||
certValidators = []provisioner.CertificateValidator{}
|
||||
issIdentity = a.intermediateIdentity
|
||||
|
@ -81,7 +81,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
|
|||
case provisioner.ProfileModifier:
|
||||
mods = append(mods, k.Option(signOpts))
|
||||
default:
|
||||
return nil, errs.InternalServerError(errors.Errorf("authority.Sign; invalid extra option type %T", k), opts...)
|
||||
return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -131,7 +131,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
|
|||
// Renew creates a new Certificate identical to the old certificate, except
|
||||
// with a validity window that begins 'now'.
|
||||
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||
opts := []errs.Option{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
|
||||
opts := []interface{}{errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String())}
|
||||
|
||||
// Check step provisioner extensions
|
||||
if err := a.authorizeRenew(oldCert); err != nil {
|
||||
|
@ -237,7 +237,7 @@ type RevokeOptions struct {
|
|||
//
|
||||
// TODO: Add OCSP and CRL support.
|
||||
func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error {
|
||||
opts := []errs.Option{
|
||||
opts := []interface{}{
|
||||
errs.WithKeyVal("serialNumber", revokeOpts.Serial),
|
||||
errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode),
|
||||
errs.WithKeyVal("reason", revokeOpts.Reason),
|
||||
|
@ -281,7 +281,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
|||
var ok bool
|
||||
p, ok = a.provisioners.LoadByToken(token, &claims.Claims)
|
||||
if !ok {
|
||||
return errs.InternalServerError(errors.Errorf("authority.Revoke; provisioner not found"), opts...)
|
||||
return errs.InternalServer("authority.Revoke; provisioner not found", opts...)
|
||||
}
|
||||
rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
|
||||
if err != nil {
|
||||
|
@ -309,10 +309,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
|||
case nil:
|
||||
return nil
|
||||
case db.ErrNotImplemented:
|
||||
return errs.NotImplemented(errors.New("authority.Revoke; no persistence layer configured"), opts...)
|
||||
return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
|
||||
case db.ErrAlreadyExists:
|
||||
return errs.BadRequest(errors.Errorf("authority.Revoke; certificate with serial "+
|
||||
"number %s has already been revoked", rci.Serial), opts...)
|
||||
return errs.BadRequest("authority.Revoke; certificate with serial "+
|
||||
"number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...)
|
||||
default:
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
|
||||
}
|
||||
|
|
10
ca/client.go
10
ca/client.go
|
@ -553,7 +553,7 @@ retry:
|
|||
// verify the sha256
|
||||
sum := sha256.Sum256(root.RootPEM.Raw)
|
||||
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
|
||||
return nil, errs.BadRequest(errors.New("client.Root; root certificate SHA256 fingerprint do not match"))
|
||||
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
|
||||
}
|
||||
return &root, nil
|
||||
}
|
||||
|
@ -961,8 +961,8 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin
|
|||
retry:
|
||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u,
|
||||
errs.WithMessage("Failed to perform POST request to %s", u))
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed",
|
||||
[]interface{}{u, errs.WithMessage("Failed to perform POST request to %s", u)}...)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
|
@ -974,8 +974,8 @@ retry:
|
|||
}
|
||||
var check api.SSHCheckPrincipalResponse
|
||||
if err := readJSON(resp.Body, &check); err != nil {
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u,
|
||||
errs.WithMessage("Failed to parse response from /ssh/check-host endpoint"))
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response",
|
||||
[]interface{}{u, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")})
|
||||
}
|
||||
return &check, nil
|
||||
}
|
||||
|
|
|
@ -163,8 +163,8 @@ func TestClient_Version(t *testing.T) {
|
|||
expectedErr error
|
||||
}{
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"500", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
|
||||
{"404", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
{"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
|
||||
{"404", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -214,7 +214,7 @@ func TestClient_Health(t *testing.T) {
|
|||
expectedErr error
|
||||
}{
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"not ok", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
|
||||
{"not ok", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -268,7 +268,7 @@ func TestClient_Root(t *testing.T) {
|
|||
expectedErr error
|
||||
}{
|
||||
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil},
|
||||
{"not found", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
{"not found", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -336,9 +336,9 @@ func TestClient_Sign(t *testing.T) {
|
|||
expectedErr error
|
||||
}{
|
||||
{"ok", request, ok, 200, false, nil},
|
||||
{"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"empty request", &api.SignRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -409,8 +409,8 @@ func TestClient_Revoke(t *testing.T) {
|
|||
expectedErr error
|
||||
}{
|
||||
{"ok", request, ok, 200, false, nil},
|
||||
{"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -483,9 +483,9 @@ func TestClient_Renew(t *testing.T) {
|
|||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"empty request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"nil request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -533,7 +533,7 @@ func TestClient_Provisioners(t *testing.T) {
|
|||
ok := &api.ProvisionersResponse{
|
||||
Provisioners: provisioner.List{},
|
||||
}
|
||||
internalServerError := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
|
||||
internalServerError := errs.InternalServer("Internal Server Error")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -603,7 +603,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|||
err error
|
||||
}{
|
||||
{"ok", "kid", ok, 200, false, nil},
|
||||
{"fail", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
{"fail", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -665,8 +665,8 @@ func TestClient_Roots(t *testing.T) {
|
|||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"bad-request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -724,7 +724,7 @@ func TestClient_Federation(t *testing.T) {
|
|||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -786,7 +786,7 @@ func TestClient_SSHRoots(t *testing.T) {
|
|||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"not found", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
{"not found", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
@ -869,7 +869,7 @@ func Test_parseEndpoint(t *testing.T) {
|
|||
|
||||
func TestClient_RootFingerprint(t *testing.T) {
|
||||
ok := &api.HealthResponse{Status: "ok"}
|
||||
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
|
||||
nok := errs.InternalServer("Internal Server Error")
|
||||
|
||||
httpsServer := httptest.NewTLSServer(nil)
|
||||
defer httpsServer.Close()
|
||||
|
@ -947,7 +947,7 @@ func TestClient_SSHBastion(t *testing.T) {
|
|||
}{
|
||||
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
|
||||
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
|
||||
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
|
|
199
errs/error.go
199
errs/error.go
|
@ -62,31 +62,6 @@ type Error struct {
|
|||
Details map[string]interface{}
|
||||
}
|
||||
|
||||
// New returns a new Error. If the given error implements the StatusCoder
|
||||
// interface we will ignore the given status.
|
||||
func New(status int, err error, opts ...Option) error {
|
||||
var (
|
||||
e *Error
|
||||
ok bool
|
||||
)
|
||||
if e, ok = err.(*Error); !ok {
|
||||
if sc, ok := err.(StatusCoder); ok {
|
||||
e = &Error{Status: sc.StatusCode(), Err: err}
|
||||
} else {
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := cause.(StatusCoder); ok {
|
||||
e = &Error{Status: sc.StatusCode(), Err: err}
|
||||
} else {
|
||||
e = &Error{Status: status, Err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error in JSON format.
|
||||
type ErrorResponse struct {
|
||||
Status int `json:"status"`
|
||||
|
@ -119,10 +94,11 @@ func (e *Error) Message() string {
|
|||
|
||||
// Wrap returns an error annotating err with a stack trace at the point Wrap is
|
||||
// called, and the supplied message. If err is nil, Wrap returns nil.
|
||||
func Wrap(status int, e error, m string, opts ...Option) error {
|
||||
func Wrap(status int, e error, m string, args ...interface{}) error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
_, opts := splitOptionArgs(args)
|
||||
if err, ok := e.(*Error); ok {
|
||||
err.Err = errors.Wrap(err.Err, m)
|
||||
e = err
|
||||
|
@ -138,25 +114,12 @@ func Wrapf(status int, e error, format string, args ...interface{}) error {
|
|||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
var opts []Option
|
||||
for i, arg := range args {
|
||||
// Once we find the first Option, assume that all further arguments are Options.
|
||||
if _, ok := arg.(Option); ok {
|
||||
for _, a := range args[i:] {
|
||||
// Ignore any arguments after the first Option that are not Options.
|
||||
if opt, ok := a.(Option); ok {
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
args = args[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
as, opts := splitOptionArgs(args)
|
||||
if err, ok := e.(*Error); ok {
|
||||
err.Err = errors.Wrapf(err.Err, format, args...)
|
||||
e = err
|
||||
} else {
|
||||
e = errors.Wrapf(e, format, args...)
|
||||
e = errors.Wrapf(e, format, as...)
|
||||
}
|
||||
return StatusCodeError(status, e, opts...)
|
||||
}
|
||||
|
@ -201,24 +164,24 @@ type Messenger interface {
|
|||
func StatusCodeError(code int, e error, opts ...Option) error {
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
return BadRequest(e, opts...)
|
||||
return BadRequestErr(e, opts...)
|
||||
case http.StatusUnauthorized:
|
||||
return Unauthorized(e, opts...)
|
||||
return UnauthorizedErr(e, opts...)
|
||||
case http.StatusForbidden:
|
||||
return Forbidden(e, opts...)
|
||||
return ForbiddenErr(e, opts...)
|
||||
case http.StatusInternalServerError:
|
||||
return InternalServerError(e, opts...)
|
||||
return InternalServerErr(e, opts...)
|
||||
case http.StatusNotImplemented:
|
||||
return NotImplemented(e, opts...)
|
||||
return NotImplementedErr(e, opts...)
|
||||
default:
|
||||
return UnexpectedError(code, e, opts...)
|
||||
return UnexpectedErr(code, e, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
seeLogs = "Please see the certificate authority logs for more info."
|
||||
// BadRequestDefaultMsg 400 default msg
|
||||
BadRequestDefaultMsg = "The request could not be completed due to being poorly formatted or missing critical data. " + seeLogs
|
||||
BadRequestDefaultMsg = "The request could not be completed; malformed or missing data" + seeLogs
|
||||
// UnauthorizedDefaultMsg 401 default msg
|
||||
UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs
|
||||
// ForbiddenDefaultMsg 403 default msg
|
||||
|
@ -231,46 +194,142 @@ var (
|
|||
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
|
||||
)
|
||||
|
||||
// InternalServerError returns a 500 error with the given error.
|
||||
func InternalServerError(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg))
|
||||
return New(http.StatusInternalServerError, err, opts...)
|
||||
// splitOptionArgs splits the variadic length args into string formatting args
|
||||
// and Option(s) to apply to an Error.
|
||||
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
|
||||
indexOptionStart := -1
|
||||
for i, a := range args {
|
||||
if _, ok := a.(Option); ok {
|
||||
indexOptionStart = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// NotImplemented returns a 501 error with the given error.
|
||||
func NotImplemented(err error, opts ...Option) error {
|
||||
if indexOptionStart < 0 {
|
||||
return args, []Option{}
|
||||
}
|
||||
opts := []Option{}
|
||||
// Ignore any non-Option args that come after the first Option.
|
||||
for _, o := range args[indexOptionStart:] {
|
||||
if opt, ok := o.(Option); ok {
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
return args[:indexOptionStart], opts
|
||||
}
|
||||
|
||||
// NewErr returns a new Error. If the given error implements the StatusCoder
|
||||
// interface we will ignore the given status.
|
||||
func NewErr(status int, err error, opts ...Option) error {
|
||||
var (
|
||||
e *Error
|
||||
ok bool
|
||||
)
|
||||
if e, ok = err.(*Error); !ok {
|
||||
if sc, ok := err.(StatusCoder); ok {
|
||||
e = &Error{Status: sc.StatusCode(), Err: err}
|
||||
} else {
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := cause.(StatusCoder); ok {
|
||||
e = &Error{Status: sc.StatusCode(), Err: err}
|
||||
} else {
|
||||
e = &Error{Status: status, Err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Errorf creates a new error using the given format and status code.
|
||||
func Errorf(code int, format string, args ...interface{}) error {
|
||||
as, opts := splitOptionArgs(args)
|
||||
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
|
||||
return New(http.StatusNotImplemented, err, opts...)
|
||||
e := &Error{Status: code, Err: fmt.Errorf(format, as...)}
|
||||
for _, o := range opts {
|
||||
o(e)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// BadRequest returns an 400 error with the given error.
|
||||
func BadRequest(err error, opts ...Option) error {
|
||||
// InternalServer creates a 500 error with the given format and arguments.
|
||||
func InternalServer(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
|
||||
return Errorf(http.StatusInternalServerError, format, args...)
|
||||
}
|
||||
|
||||
// InternalServerErr returns a 500 error with the given error.
|
||||
func InternalServerErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg))
|
||||
return NewErr(http.StatusInternalServerError, err, opts...)
|
||||
}
|
||||
|
||||
// NotImplemented creates a 501 error with the given format and arguments.
|
||||
func NotImplemented(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(NotImplementedDefaultMsg))
|
||||
return Errorf(http.StatusNotImplemented, format, args...)
|
||||
}
|
||||
|
||||
// NotImplementedErr returns a 501 error with the given error.
|
||||
func NotImplementedErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
|
||||
return NewErr(http.StatusNotImplemented, err, opts...)
|
||||
}
|
||||
|
||||
// BadRequest creates a 400 error with the given format and arguments.
|
||||
func BadRequest(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(BadRequestDefaultMsg))
|
||||
return Errorf(http.StatusBadRequest, format, args...)
|
||||
}
|
||||
|
||||
// BadRequestErr returns an 400 error with the given error.
|
||||
func BadRequestErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
|
||||
return New(http.StatusBadRequest, err, opts...)
|
||||
return NewErr(http.StatusBadRequest, err, opts...)
|
||||
}
|
||||
|
||||
// Unauthorized returns an 401 error with the given error.
|
||||
func Unauthorized(err error, opts ...Option) error {
|
||||
// Unauthorized creates a 401 error with the given format and arguments.
|
||||
func Unauthorized(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(UnauthorizedDefaultMsg))
|
||||
return Errorf(http.StatusUnauthorized, format, args...)
|
||||
}
|
||||
|
||||
// UnauthorizedErr returns an 401 error with the given error.
|
||||
func UnauthorizedErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg))
|
||||
return New(http.StatusUnauthorized, err, opts...)
|
||||
return NewErr(http.StatusUnauthorized, err, opts...)
|
||||
}
|
||||
|
||||
// Forbidden returns an 403 error with the given error.
|
||||
func Forbidden(err error, opts ...Option) error {
|
||||
// Forbidden creates a 403 error with the given format and arguments.
|
||||
func Forbidden(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(ForbiddenDefaultMsg))
|
||||
return Errorf(http.StatusForbidden, format, args...)
|
||||
}
|
||||
|
||||
// ForbiddenErr returns an 403 error with the given error.
|
||||
func ForbiddenErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
|
||||
return New(http.StatusForbidden, err, opts...)
|
||||
return NewErr(http.StatusForbidden, err, opts...)
|
||||
}
|
||||
|
||||
// NotFound returns an 404 error with the given error.
|
||||
func NotFound(err error, opts ...Option) error {
|
||||
// NotFound creates a 404 error with the given format and arguments.
|
||||
func NotFound(format string, args ...interface{}) error {
|
||||
args = append(args, withDefaultMessage(NotFoundDefaultMsg))
|
||||
return Errorf(http.StatusNotFound, format, args...)
|
||||
}
|
||||
|
||||
// NotFoundErr returns an 404 error with the given error.
|
||||
func NotFoundErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(NotFoundDefaultMsg))
|
||||
return New(http.StatusNotFound, err, opts...)
|
||||
return NewErr(http.StatusNotFound, err, opts...)
|
||||
}
|
||||
|
||||
// UnexpectedError will be used when the certificate authority makes an outgoing
|
||||
// UnexpectedErr will be used when the certificate authority makes an outgoing
|
||||
// request and receives an unhandled status code.
|
||||
func UnexpectedError(code int, err error, opts ...Option) error {
|
||||
func UnexpectedErr(code int, err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage("The certificate authority received an "+
|
||||
"unexpected HTTP status code - '%d'. "+seeLogs, code))
|
||||
return New(code, err, opts...)
|
||||
return NewErr(code, err, opts...)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue