diff --git a/api/api.go b/api/api.go index f1013c0a..a001dac6 100644 --- a/api/api.go +++ b/api/api.go @@ -26,6 +26,7 @@ import ( // Authority is the interface implemented by a CA authority. type Authority interface { + SSHAuthority // NOTE: Authorize will be deprecated in future releases. Please use the // context specific Authoirize[Sign|Revoke|etc.] methods. Authorize(ott string) ([]provisioner.SignOption, error) @@ -249,6 +250,8 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("GET", "/federation", h.Federation) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) + // SSH CA + r.MethodFunc("POST", "/sign-ssh", h.SignSSH) } // Health is an HTTP handler that returns the status of the server. diff --git a/api/ssh.go b/api/ssh.go new file mode 100644 index 00000000..0e3e4791 --- /dev/null +++ b/api/ssh.go @@ -0,0 +1,150 @@ +package api + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "net/http" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "golang.org/x/crypto/ssh" +) + +// SSHAuthority is the interface implemented by a SSH CA authority. +type SSHAuthority interface { + SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) +} + +// SignSSHRequest is the request body of an SSH certificate request. +type SignSSHRequest struct { + PublicKey []byte `json:"publicKey"` //base64 encoded + OTT string `json:"ott"` + CertType string `json:"certType"` + Principals []string `json:"principals"` + ValidAfter TimeDuration `json:"validAfter"` + ValidBefore TimeDuration `json:"validBefore"` +} + +// SignSSHResponse is the response object that returns the SSH certificate. +type SignSSHResponse struct { + Certificate SSHCertificate `json:"crt"` +} + +// SSHCertificate represents the response SSH certificate. +type SSHCertificate struct { + *ssh.Certificate +} + +// MarshalJSON implements the json.Marshaler interface. The certificate is +// quoted string using the PEM encoding. +func (c SSHCertificate) MarshalJSON() ([]byte, error) { + if c.Certificate == nil { + return []byte("null"), nil + } + s := base64.StdEncoding.EncodeToString(c.Certificate.Marshal()) + return []byte(`"` + s + `"`), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. The certificate is +// expected to be a quoted string using the PEM encoding. +func (c *SSHCertificate) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return errors.Wrap(err, "error decoding certificate") + } + certData, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return errors.Wrap(err, "error decoding certificate") + } + pub, err := ssh.ParsePublicKey(certData) + if err != nil { + return errors.Wrap(err, "error decoding certificate") + } + cert, ok := pub.(*ssh.Certificate) + if !ok { + return errors.Errorf("error decoding certificate: %T is not an *ssh.Certificate", pub) + } + c.Certificate = cert + return nil +} + +// Validate validates the SignSSHRequest. +func (s *SignSSHRequest) Validate() error { + switch { + case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: + return errors.Errorf("unknown certType %s", s.CertType) + case len(s.PublicKey) == 0: + return errors.New("missing or empty publicKey") + case len(s.OTT) == 0: + return errors.New("missing or empty ott") + default: + return nil + } +} + +// ParsePublicKey returns the ssh.PublicKey from the request. +func (s *SignSSHRequest) ParsePublicKey() (ssh.PublicKey, error) { + // Validate pub key. + data := make([]byte, base64.StdEncoding.DecodedLen(len(s.PublicKey))) + if _, err := base64.StdEncoding.Decode(data, s.PublicKey); err != nil { + return nil, errors.Wrap(err, "error decoding publicKey") + } + + // Trim padding from end of key. + data = bytes.TrimRight(data, "\x00") + publicKey, err := ssh.ParsePublicKey(data) + if err != nil { + return nil, errors.Wrap(err, "error parsing publicKey") + } + + return publicKey, nil +} + +// SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token +// (ott) from the body and creates a new SSH certificate with the information in +// the request. +func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) { + var body SignSSHRequest + if err := ReadJSON(r.Body, &body); err != nil { + WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) + return + } + + logOtt(w, body.OTT) + if err := body.Validate(); err != nil { + WriteError(w, err) + return + } + + publicKey, err := body.ParsePublicKey() + if err != nil { + WriteError(w, BadRequest(err)) + return + } + + opts := provisioner.SSHOptions{ + CertType: body.CertType, + Principals: body.Principals, + ValidBefore: body.ValidBefore, + ValidAfter: body.ValidAfter, + } + + signOpts, err := h.Authority.AuthorizeSign(body.OTT) + if err != nil { + WriteError(w, Unauthorized(err)) + return + } + + cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...) + if err != nil { + WriteError(w, Forbidden(err)) + return + } + + w.WriteHeader(http.StatusCreated) + // logCertificate(w, cert) + JSON(w, &SignSSHResponse{ + Certificate: SSHCertificate{cert}, + }) +} diff --git a/authority/authority.go b/authority/authority.go index 33340029..c4d9d1cd 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -1,6 +1,7 @@ package authority import ( + "crypto" "crypto/sha256" "crypto/x509" "encoding/hex" @@ -20,6 +21,8 @@ type Authority struct { config *Config rootX509Certs []*x509.Certificate intermediateIdentity *x509util.Identity + sshCAUserCertSignKey crypto.Signer + sshCAHostCertSignKey crypto.Signer validateOnce bool certificates *sync.Map startTime time.Time @@ -117,6 +120,9 @@ func (a *Authority) init() error { } } + a.sshCAHostCertSignKey = a.intermediateIdentity.Key.(crypto.Signer) + a.sshCAUserCertSignKey = a.intermediateIdentity.Key.(crypto.Signer) + // Store all the provisioners for _, p := range a.config.AuthorityConfig.Provisioners { if err := a.provisioners.Store(p); err != nil { diff --git a/authority/config.go b/authority/config.go index 77854812..7cfdf744 100644 --- a/authority/config.go +++ b/authority/config.go @@ -29,10 +29,16 @@ var ( } defaultDisableRenewal = false globalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &provisioner.Duration{Duration: 4 * time.Hour}, + MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, } ) diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 1109e0c7..23c85002 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -8,10 +8,18 @@ import ( // Claims so that individual provisioners can override global claims. type Claims struct { + // TLS CA properties MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` DisableRenewal *bool `json:"disableRenewal,omitempty"` + // SSH CA properties + MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` + MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` + DefaultUserSSHDur *Duration `json:"defaultUserSSHCertDuration,omitempty"` + MinHostSSHDur *Duration `json:"minHostSSHCertDuration,omitempty"` + MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` + DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` } // Claimer is the type that controls claims. It provides an interface around the @@ -31,10 +39,16 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) { func (c *Claimer) Claims() Claims { disableRenewal := c.IsDisableRenewal() return Claims{ - MinTLSDur: &Duration{c.MinTLSCertDuration()}, - MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, - DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, - DisableRenewal: &disableRenewal, + MinTLSDur: &Duration{c.MinTLSCertDuration()}, + MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, + DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, + DisableRenewal: &disableRenewal, + MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, + MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, + DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, + MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, + MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, + DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, } } @@ -78,6 +92,66 @@ func (c *Claimer) IsDisableRenewal() bool { return *c.claims.DisableRenewal } +// DefaultUserSSHCertDuration returns the default SSH user cert duration for the +// provisioner. If the default is not set within the provisioner, then the +// global default from the authority configuration will be used. +func (c *Claimer) DefaultUserSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.DefaultUserSSHDur == nil { + return c.global.DefaultUserSSHDur.Duration + } + return c.claims.DefaultUserSSHDur.Duration +} + +// MinUserSSHCertDuration returns the minimum SSH user cert duration for the +// provisioner. If the minimum is not set within the provisioner, then the +// global minimum from the authority configuration will be used. +func (c *Claimer) MinUserSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.MinUserSSHDur == nil { + return c.global.MinUserSSHDur.Duration + } + return c.claims.MinUserSSHDur.Duration +} + +// MaxUserSSHCertDuration returns the maximum SSH user cert duration for the +// provisioner. If the maximum is not set within the provisioner, then the +// global maximum from the authority configuration will be used. +func (c *Claimer) MaxUserSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.MaxUserSSHDur == nil { + return c.global.MaxUserSSHDur.Duration + } + return c.claims.MaxUserSSHDur.Duration +} + +// DefaultHostSSHCertDuration returns the default SSH host cert duration for the +// provisioner. If the default is not set within the provisioner, then the +// global default from the authority configuration will be used. +func (c *Claimer) DefaultHostSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.DefaultHostSSHDur == nil { + return c.global.DefaultHostSSHDur.Duration + } + return c.claims.DefaultHostSSHDur.Duration +} + +// MinHostSSHCertDuration returns the minimum SSH host cert duration for the +// provisioner. If the minimum is not set within the provisioner, then the +// global minimum from the authority configuration will be used. +func (c *Claimer) MinHostSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.MinHostSSHDur == nil { + return c.global.MinHostSSHDur.Duration + } + return c.claims.MinHostSSHDur.Duration +} + +// MaxHostSSHCertDuration returns the maximum SSH Host cert duration for the +// provisioner. If the maximum is not set within the provisioner, then the +// global maximum from the authority configuration will be used. +func (c *Claimer) MaxHostSSHCertDuration() time.Duration { + if c.claims == nil || c.claims.MaxHostSSHDur == nil { + return c.global.MaxHostSSHDur.Duration + } + return c.claims.MaxHostSSHDur.Duration +} + // Validate validates and modifies the Claims with default values. func (c *Claimer) Validate() error { var ( diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index dca5dce9..a21815b9 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -12,7 +12,12 @@ import ( // jwtPayload extends jwt.Claims with step attributes. type jwtPayload struct { jose.Claims - SANs []string `json:"sans,omitempty"` + SANs []string `json:"sans,omitempty"` + Step *stepPayload `json:"step,omitempty"` +} + +type stepPayload struct { + SSH *SSHOptions `json:"ssh,omitempty"` } // JWK is the default provisioner, an entity that can sign tokens necessary for @@ -134,6 +139,12 @@ func (p *JWK) AuthorizeSign(token string) ([]SignOption, error) { if err != nil { return nil, err } + + // Check for SSH token + if claims.Step != nil && claims.Step.SSH != nil { + return p.authorizeSSHSign(claims) + } + // NOTE: This is for backwards compatibility with older versions of cli // and certificates. Older versions added the token subject as the only SAN // in a CSR by default. @@ -159,3 +170,37 @@ func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error { } return nil } + +func (p *JWK) authorizeSSHSign(claims *jwtPayload) ([]SignOption, error) { + t := now() + opts := claims.Step.SSH + signOptions := []SignOption{ + // validates user's SSHOptions with the ones in the token + &sshCertificateOptionsValidator{opts}, + // set the default extensions + &sshDefaultExtensionModifier{}, + // set the key id to the token subject + sshCertificateKeyIDModifier(claims.Subject), + } + + // Add modifiers from custom claims + if opts.CertType != "" { + signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType)) + } + if len(opts.Principals) > 0 { + signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals)) + } + if !opts.ValidAfter.IsZero() { + signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) + } + if !opts.ValidBefore.IsZero() { + signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix())) + } + + return append(signOptions, + // checks the validity bounds, and set the validity if has not been set + &sshCertificateValidityModifier{p.claimer}, + // require all the fields in the SSH certificate + &sshCertificateDefaultValidator{}, + ), nil +} diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go new file mode 100644 index 00000000..3f4412ad --- /dev/null +++ b/authority/provisioner/sign_ssh_options.go @@ -0,0 +1,290 @@ +package provisioner + +import ( + "fmt" + "time" + + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + +const ( + // SSHUserCert is the string used to represent ssh.UserCert. + SSHUserCert = "user" + + // SSHHostCert is the string used to represent ssh.HostCert. + SSHHostCert = "host" +) + +// SSHCertificateModifier is the interface used to change properties in an SSH +// certificate. +type SSHCertificateModifier interface { + SignOption + Modify(cert *ssh.Certificate) error +} + +// SSHCertificateOptionModifier is the interface used to add custom options used +// to modify the SSH certificate. +type SSHCertificateOptionModifier interface { + SignOption + Option(o SSHOptions) SSHCertificateModifier +} + +// SSHCertificateValidator is the interface used to validate an SSH certificate. +type SSHCertificateValidator interface { + SignOption + Valid(crt *ssh.Certificate) error +} + +// SSHCertificateOptionsValidator is the interface used to validate the custom +// options used to modify the SSH certificate. +type SSHCertificateOptionsValidator interface { + SignOption + Valid(got SSHOptions) error +} + +// SSHOptions contains the options that can be passed to the SignSSH method. +type SSHOptions struct { + CertType string `json:"certType"` + Principals []string `json:"principals"` + ValidAfter TimeDuration `json:"validAfter,omitempty"` + ValidBefore TimeDuration `json:"validBefore,omitempty"` +} + +// Type returns the uint32 representation of the CertType. +func (o SSHOptions) Type() uint32 { + return sshCertTypeUInt32(o.CertType) +} + +// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate. +func (o SSHOptions) Modify(cert *ssh.Certificate) error { + switch o.CertType { + case "": // ignore + case SSHUserCert: + cert.CertType = ssh.UserCert + case SSHHostCert: + cert.CertType = ssh.HostCert + default: + return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType) + } + cert.ValidPrincipals = o.Principals + if !o.ValidAfter.IsZero() { + cert.ValidAfter = uint64(o.ValidAfter.Time().Unix()) + } + if !o.ValidBefore.IsZero() { + cert.ValidBefore = uint64(o.ValidBefore.Time().Unix()) + } + if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { + return errors.New("ssh certificate valid after cannot be greater than valid before") + } + return nil +} + +// match compares two SSHOptions and return an error if they don't match. It +// ignores zero values. +func (o SSHOptions) match(got SSHOptions) error { + if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType { + return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType) + } + if len(o.Principals) > 0 && len(got.Principals) > 0 && !equalStringSlice(o.Principals, got.Principals) { + return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals) + } + if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) { + return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter) + } + if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) { + return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore) + } + return nil +} + +// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given +// Key ID in the SSH certificate. +type sshCertificateKeyIDModifier string + +func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error { + cert.KeyId = string(m) + return nil +} + +// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the +// certificate type to the SSH certificate. +type sshCertificateCertTypeModifier string + +func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error { + cert.CertType = sshCertTypeUInt32(string(m)) + return nil +} + +// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the +// principals to the SSH certificate. +type sshCertificatePrincipalsModifier []string + +func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error { + cert.ValidPrincipals = []string(m) + return nil +} + +// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the +// ValidAfter in the SSH certificate. +type sshCertificateValidAfterModifier uint64 + +func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error { + cert.ValidAfter = uint64(m) + return nil +} + +// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the +// ValidBefore in the SSH certificate. +type sshCertificateValidBeforeModifier uint64 + +func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error { + cert.ValidBefore = uint64(m) + return nil +} + +// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets +// the default extensions in an SSH certificate. +type sshDefaultExtensionModifier struct{} + +func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error { + if cert.Extensions == nil { + cert.Extensions = make(map[string]string) + } + cert.Extensions["permit-X11-forwarding"] = "" + cert.Extensions["permit-agent-forwarding"] = "" + cert.Extensions["permit-port-forwarding"] = "" + cert.Extensions["permit-pty"] = "" + cert.Extensions["permit-user-rc"] = "" + return nil +} + +// sshCertificateValidityModifier is a SSHCertificateModifier checks the +// validity bounds, setting them if they are not provided. It will fail if a +// CertType has not been set or is not valid. +type sshCertificateValidityModifier struct { + *Claimer +} + +func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error { + var d, min, max time.Duration + switch cert.CertType { + case ssh.UserCert: + d = m.DefaultUserSSHCertDuration() + min = m.MinUserSSHCertDuration() + max = m.MaxUserSSHCertDuration() + case ssh.HostCert: + d = m.DefaultHostSSHCertDuration() + min = m.MinHostSSHCertDuration() + max = m.MaxHostSSHCertDuration() + case 0: + return errors.New("ssh certificate type has not been set") + default: + return errors.Errorf("unknown ssh certificate type %d", cert.CertType) + } + + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now().Unix()) + } + if cert.ValidBefore == 0 { + t := time.Unix(int64(cert.ValidAfter), 0) + cert.ValidBefore = uint64(t.Add(d).Unix()) + } + + diff := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second + switch { + case diff < max: + return errors.Errorf("ssh certificate duration cannot be lower than %s", min) + case diff > max: + return errors.Errorf("ssh certificate duration cannot be greater than %s", max) + default: + return nil + } +} + +// sshCertificateOptionsValidator validates the user SSHOptions with the ones +// usually present in the token. +type sshCertificateOptionsValidator struct { + *SSHOptions +} + +func (want *sshCertificateOptionsValidator) Valid(got SSHOptions) error { + return want.match(got) +} + +// sshCertificateDefaultValidator implements a simple validator for all the +// fields in the SSH certificate. +type sshCertificateDefaultValidator struct{} + +// Valid returns an error if the given certificate does not contain the necessary fields. +func (v *sshCertificateDefaultValidator) Valid(crt *ssh.Certificate) error { + switch { + case len(crt.Nonce) == 0: + return errors.New("ssh certificate nonce cannot be empty") + case crt.Key == nil: + return errors.New("ssh certificate key cannot be nil") + case crt.Serial == 0: + return errors.New("ssh certificate serial cannot be 0") + case crt.CertType != ssh.UserCert && crt.CertType != ssh.HostCert: + return errors.Errorf("ssh certificate has an unknown type: %d", crt.CertType) + case crt.KeyId == "": + return errors.New("ssh certificate key id cannot be empty") + case len(crt.ValidPrincipals) == 0: + return errors.New("ssh certificate valid principals cannot be empty") + case crt.ValidAfter == 0: + return errors.New("ssh certificate valid after cannot be 0") + case crt.ValidBefore == 0: + return errors.New("ssh certificate valid before cannot be 0") + case len(crt.Extensions) == 0: + return errors.New("ssh certificate extensions cannot be empty") + case crt.SignatureKey == nil: + return errors.New("ssh certificate signature key cannot be nil") + case crt.Signature == nil: + return errors.New("ssh certificate signature cannot be nil") + default: + return nil + } +} + +// sshCertTypeName returns the string representation of the given ssh.CertType. +func sshCertTypeString(ct uint32) string { + switch ct { + case 0: + return "" + case ssh.UserCert: + return SSHUserCert + case ssh.HostCert: + return SSHHostCert + default: + return fmt.Sprintf("unknown (%d)", ct) + } +} + +// sshCertTypeUInt32 +func sshCertTypeUInt32(ct string) uint32 { + switch ct { + case SSHUserCert: + return ssh.UserCert + case SSHHostCert: + return ssh.HostCert + default: + return 0 + } +} + +func equalStringSlice(a, b []string) bool { + var l int + if l = len(a); l != len(b) { + return false + } + visit := make(map[string]struct{}, l) + for i := 0; i < l; i++ { + visit[a[i]] = struct{}{} + } + for i := 0; i < l; i++ { + if _, ok := visit[b[i]]; !ok { + return false + } + } + return true +} diff --git a/authority/provisioner/timeduration.go b/authority/provisioner/timeduration.go index fea967d5..33104df3 100644 --- a/authority/provisioner/timeduration.go +++ b/authority/provisioner/timeduration.go @@ -57,6 +57,17 @@ func (t *TimeDuration) SetTime(tt time.Time) { t.t, t.d = tt, 0 } +// IsZero returns true the TimeDuration represents the zero value, false +// otherwise. +func (t *TimeDuration) IsZero() bool { + return t.t.IsZero() && t.d == 0 +} + +// Equal returns if t and other are equal. +func (t *TimeDuration) Equal(other *TimeDuration) bool { + return t.t.Equal(other.t) && t.d == other.d +} + // MarshalJSON implements the json.Marshaler interface. If the time is set it // will return the time in RFC 3339 format if not it will return the duration // string. diff --git a/authority/ssh.go b/authority/ssh.go new file mode 100644 index 00000000..952811ec --- /dev/null +++ b/authority/ssh.go @@ -0,0 +1,114 @@ +package authority + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "net/http" + "strings" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/crypto/randutil" + "golang.org/x/crypto/ssh" +) + +func generateSSHPublicKeyID(key ssh.PublicKey) string { + sum := sha256.Sum256(key.Marshal()) + return strings.ToLower(hex.EncodeToString(sum[:])) +} + +// SignSSH creates a signed SSH certificate with the given public key and options. +func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + var mods []provisioner.SSHCertificateModifier + var validators []provisioner.SSHCertificateValidator + + for _, op := range signOpts { + switch o := op.(type) { + // modify the ssh.Certificate + case provisioner.SSHCertificateModifier: + mods = append(mods, o) + // modify the ssh.Certificate given the SSHOptions + case provisioner.SSHCertificateOptionModifier: + mods = append(mods, o.Option(opts)) + // validate the ssh.Certificate + case provisioner.SSHCertificateValidator: + validators = append(validators, o) + // validate the given SSHOptions + case provisioner.SSHCertificateOptionsValidator: + if err := o.Valid(opts); err != nil { + return nil, &apiError{err: err, code: http.StatusUnauthorized} + } + default: + return nil, &apiError{ + err: errors.Errorf("signSSH: invalid extra option type %T", o), + code: http.StatusInternalServerError, + } + } + } + + nonce, err := randutil.ASCII(32) + if err != nil { + return nil, err + } + + var serial uint64 + if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { + return nil, errors.Wrap(err, "error reading random number") + } + + // Build base certificate with the key and some random values + cert := &ssh.Certificate{ + Nonce: []byte(nonce), + Key: key, + Serial: serial, + } + + // Use opts to modify the certificate + if err := opts.Modify(cert); err != nil { + return nil, err + } + + // Use provisioner modifiers + for _, m := range mods { + if err := m.Modify(cert); err != nil { + return nil, &apiError{err: err, code: http.StatusInternalServerError} + } + } + + // Get signer from authority keys + var signer ssh.Signer + switch cert.CertType { + case ssh.UserCert: + signer, err = ssh.NewSignerFromSigner(a.sshCAUserCertSignKey) + case ssh.HostCert: + signer, err = ssh.NewSignerFromSigner(a.sshCAHostCertSignKey) + default: + return nil, &apiError{ + err: errors.Errorf("unexpected ssh certificate type: %d", cert.CertType), + code: http.StatusInternalServerError, + } + } + cert.SignatureKey = signer.PublicKey() + + // Get bytes for signing trailing the signature length. + data := cert.Marshal() + data = data[:len(data)-4] + + // Sign the certificate + sig, err := signer.Sign(rand.Reader, data) + if err != nil { + return nil, err + } + cert.Signature = sig + + // User provisioners validators + for _, v := range validators { + if err := v.Valid(cert); err != nil { + return nil, &apiError{err: err, code: http.StatusUnauthorized} + } + } + + return cert, nil +} diff --git a/ca/client.go b/ca/client.go index c9766293..30a43924 100644 --- a/ca/client.go +++ b/ca/client.go @@ -373,6 +373,28 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { return &sign, nil } +// SignSSH performs the SSH certificate sign request to the CA and returns the +// api.SignSSHResponse struct. +func (c *Client) SignSSH(req *api.SignSSHRequest) (*api.SignSSHResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, errors.Wrap(err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: "/sign-ssh"}) + resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var sign api.SignSSHResponse + if err := readJSON(resp.Body, &sign); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &sign, nil +} + // Renew performs the renew request to the CA and returns the api.SignResponse // struct. func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {