Return always http errors in sign ssh options.

This commit is contained in:
Mariano Cano 2021-11-23 17:52:39 -08:00
parent 031d4d7000
commit 1da7ea6646

View file

@ -58,6 +58,11 @@ func (o SignSSHOptions) Validate() error {
if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert { if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert {
return errs.BadRequest("unknown certificate type '%s'", o.CertType) return errs.BadRequest("unknown certificate type '%s'", o.CertType)
} }
for _, p := range o.Principals {
if p == "" {
return errs.BadRequest("principals cannot contain empty values")
}
}
return nil return nil
} }
@ -75,7 +80,7 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
case SSHHostCert: case SSHHostCert:
cert.CertType = ssh.HostCert cert.CertType = ssh.HostCert
default: default:
return errors.Errorf("ssh certificate has an unknown type - %s", o.CertType) return errs.BadRequest("ssh certificate has an unknown type '%s'", o.CertType)
} }
cert.KeyId = o.KeyID cert.KeyId = o.KeyID
@ -95,7 +100,7 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix()) cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix())
} }
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
return errors.New("ssh certificate valid after cannot be greater than valid before") return errs.BadRequest("ssh certificate validAfter cannot be greater than validBefore")
} }
return nil return nil
} }
@ -104,16 +109,16 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
// ignores zero values. // ignores zero values.
func (o SignSSHOptions) match(got SignSSHOptions) error { func (o SignSSHOptions) match(got SignSSHOptions) error {
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType { if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType) return errs.BadRequest("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
} }
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) { if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals) return errs.BadRequest("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
} }
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) { if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter) return errs.BadRequest("ssh certificate validAfter does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
} }
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) { if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore) return errs.BadRequest("ssh certificate validBefore does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
} }
return nil return nil
} }
@ -206,7 +211,7 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate, _ SignSSHOpt
cert.Extensions["permit-user-rc"] = "" cert.Extensions["permit-user-rc"] = ""
return nil return nil
default: default:
return errors.New("ssh certificate type has not been set or is invalid") return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType)
} }
} }
@ -272,7 +277,7 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error
certValidAfter := time.Unix(int64(cert.ValidAfter), 0) certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
if certValidAfter.After(m.NotAfter) { if certValidAfter.After(m.NotAfter) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)", return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
m.NotAfter, certValidAfter) m.NotAfter, certValidAfter)
} }
@ -285,7 +290,7 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error
} else { } else {
certValidBefore := time.Unix(int64(cert.ValidBefore), 0) certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
if m.NotAfter.Before(certValidBefore) { if m.NotAfter.Before(certValidBefore) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)", return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
m.NotAfter, certValidBefore) m.NotAfter, certValidBefore)
} }
} }
@ -319,11 +324,11 @@ type sshCertOptionsRequireValidator struct {
func (v *sshCertOptionsRequireValidator) Valid(got SignSSHOptions) error { func (v *sshCertOptionsRequireValidator) Valid(got SignSSHOptions) error {
switch { switch {
case v.CertType && got.CertType == "": case v.CertType && got.CertType == "":
return errors.New("ssh certificate certType cannot be empty") return errs.BadRequest("ssh certificate certType cannot be empty")
case v.KeyID && got.KeyID == "": case v.KeyID && got.KeyID == "":
return errors.New("ssh certificate keyID cannot be empty") return errs.BadRequest("ssh certificate keyID cannot be empty")
case v.Principals && len(got.Principals) == 0: case v.Principals && len(got.Principals) == 0:
return errors.New("ssh certificate principals cannot be empty") return errs.BadRequest("ssh certificate principals cannot be empty")
default: default:
return nil return nil
} }
@ -354,7 +359,7 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
case 0: case 0:
return errs.BadRequest("ssh certificate type has not been set") return errs.BadRequest("ssh certificate type has not been set")
default: default:
return errs.BadRequest("unknown ssh certificate type %d", cert.CertType) return errs.BadRequest("unknown ssh certificate type '%d'", cert.CertType)
} }
// To not take into account the backdate, time.Now() will be used to // To not take into account the backdate, time.Now() will be used to
@ -381,25 +386,25 @@ type sshCertDefaultValidator struct{}
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error { func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
switch { switch {
case len(cert.Nonce) == 0: case len(cert.Nonce) == 0:
return errors.New("ssh certificate nonce cannot be empty") return errs.Forbidden("ssh certificate nonce cannot be empty")
case cert.Key == nil: case cert.Key == nil:
return errors.New("ssh certificate key cannot be nil") return errs.Forbidden("ssh certificate key cannot be nil")
case cert.Serial == 0: case cert.Serial == 0:
return errors.New("ssh certificate serial cannot be 0") return errs.Forbidden("ssh certificate serial cannot be 0")
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert: case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType) return errs.Forbidden("ssh certificate has an unknown type '%d'", cert.CertType)
case cert.KeyId == "": case cert.KeyId == "":
return errors.New("ssh certificate key id cannot be empty") return errs.Forbidden("ssh certificate key id cannot be empty")
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0") return errs.Forbidden("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()): case cert.ValidBefore < uint64(now().Unix()):
return errors.New("ssh certificate validBefore cannot be in the past") return errs.Forbidden("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter: case cert.ValidBefore < cert.ValidAfter:
return errors.New("ssh certificate validBefore cannot be before validAfter") return errs.Forbidden("ssh certificate validBefore cannot be before validAfter")
case cert.SignatureKey == nil: case cert.SignatureKey == nil:
return errors.New("ssh certificate signature key cannot be nil") return errs.Forbidden("ssh certificate signature key cannot be nil")
case cert.Signature == nil: case cert.Signature == nil:
return errors.New("ssh certificate signature cannot be nil") return errs.Forbidden("ssh certificate signature cannot be nil")
default: default:
return nil return nil
} }
@ -411,25 +416,25 @@ type sshDefaultPublicKeyValidator struct{}
// Valid checks that certificate request common name matches the one configured. // Valid checks that certificate request common name matches the one configured.
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error { func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
if cert.Key == nil { if cert.Key == nil {
return errors.New("ssh certificate key cannot be nil") return errs.BadRequest("ssh certificate key cannot be nil")
} }
switch cert.Key.Type() { switch cert.Key.Type() {
case ssh.KeyAlgoRSA: case ssh.KeyAlgoRSA:
_, in, ok := sshParseString(cert.Key.Marshal()) _, in, ok := sshParseString(cert.Key.Marshal())
if !ok { if !ok {
return errors.New("ssh certificate key is invalid") return errs.BadRequest("ssh certificate key is invalid")
} }
key, err := sshParseRSAPublicKey(in) key, err := sshParseRSAPublicKey(in)
if err != nil { if err != nil {
return err return errs.BadRequestErr(err, "error parsing public key")
} }
if key.Size() < keyutil.MinRSAKeyBytes { if key.Size() < keyutil.MinRSAKeyBytes {
return errors.Errorf("ssh certificate key must be at least %d bits (%d bytes)", return errs.Forbidden("ssh certificate key must be at least %d bits (%d bytes)",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes) 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)
} }
return nil return nil
case ssh.KeyAlgoDSA: case ssh.KeyAlgoDSA:
return errors.New("ssh certificate key algorithm (DSA) is not supported") return errs.BadRequest("ssh certificate key algorithm (DSA) is not supported")
default: default:
return nil return nil
} }