Add initial implementation of an SSH CA using the JWK provisioner.
Fixes smallstep/ca-component#187
This commit is contained in:
parent
5356bce4d8
commit
1c8f610ca9
10 changed files with 730 additions and 9 deletions
|
@ -26,6 +26,7 @@ import (
|
||||||
|
|
||||||
// Authority is the interface implemented by a CA authority.
|
// Authority is the interface implemented by a CA authority.
|
||||||
type Authority interface {
|
type Authority interface {
|
||||||
|
SSHAuthority
|
||||||
// NOTE: Authorize will be deprecated in future releases. Please use the
|
// NOTE: Authorize will be deprecated in future releases. Please use the
|
||||||
// context specific Authoirize[Sign|Revoke|etc.] methods.
|
// context specific Authoirize[Sign|Revoke|etc.] methods.
|
||||||
Authorize(ott string) ([]provisioner.SignOption, error)
|
Authorize(ott string) ([]provisioner.SignOption, error)
|
||||||
|
@ -249,6 +250,8 @@ func (h *caHandler) Route(r Router) {
|
||||||
r.MethodFunc("GET", "/federation", h.Federation)
|
r.MethodFunc("GET", "/federation", h.Federation)
|
||||||
// For compatibility with old code:
|
// For compatibility with old code:
|
||||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||||
|
// SSH CA
|
||||||
|
r.MethodFunc("POST", "/sign-ssh", h.SignSSH)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health is an HTTP handler that returns the status of the server.
|
// Health is an HTTP handler that returns the status of the server.
|
||||||
|
|
150
api/ssh.go
Normal file
150
api/ssh.go
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHAuthority is the interface implemented by a SSH CA authority.
|
||||||
|
type SSHAuthority interface {
|
||||||
|
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSHRequest is the request body of an SSH certificate request.
|
||||||
|
type SignSSHRequest struct {
|
||||||
|
PublicKey []byte `json:"publicKey"` //base64 encoded
|
||||||
|
OTT string `json:"ott"`
|
||||||
|
CertType string `json:"certType"`
|
||||||
|
Principals []string `json:"principals"`
|
||||||
|
ValidAfter TimeDuration `json:"validAfter"`
|
||||||
|
ValidBefore TimeDuration `json:"validBefore"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSHResponse is the response object that returns the SSH certificate.
|
||||||
|
type SignSSHResponse struct {
|
||||||
|
Certificate SSHCertificate `json:"crt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificate represents the response SSH certificate.
|
||||||
|
type SSHCertificate struct {
|
||||||
|
*ssh.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the json.Marshaler interface. The certificate is
|
||||||
|
// quoted string using the PEM encoding.
|
||||||
|
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
|
||||||
|
if c.Certificate == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
s := base64.StdEncoding.EncodeToString(c.Certificate.Marshal())
|
||||||
|
return []byte(`"` + s + `"`), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the json.Unmarshaler interface. The certificate is
|
||||||
|
// expected to be a quoted string using the PEM encoding.
|
||||||
|
func (c *SSHCertificate) UnmarshalJSON(data []byte) error {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return errors.Wrap(err, "error decoding certificate")
|
||||||
|
}
|
||||||
|
certData, err := base64.StdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error decoding certificate")
|
||||||
|
}
|
||||||
|
pub, err := ssh.ParsePublicKey(certData)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error decoding certificate")
|
||||||
|
}
|
||||||
|
cert, ok := pub.(*ssh.Certificate)
|
||||||
|
if !ok {
|
||||||
|
return errors.Errorf("error decoding certificate: %T is not an *ssh.Certificate", pub)
|
||||||
|
}
|
||||||
|
c.Certificate = cert
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the SignSSHRequest.
|
||||||
|
func (s *SignSSHRequest) Validate() error {
|
||||||
|
switch {
|
||||||
|
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
|
||||||
|
return errors.Errorf("unknown certType %s", s.CertType)
|
||||||
|
case len(s.PublicKey) == 0:
|
||||||
|
return errors.New("missing or empty publicKey")
|
||||||
|
case len(s.OTT) == 0:
|
||||||
|
return errors.New("missing or empty ott")
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePublicKey returns the ssh.PublicKey from the request.
|
||||||
|
func (s *SignSSHRequest) ParsePublicKey() (ssh.PublicKey, error) {
|
||||||
|
// Validate pub key.
|
||||||
|
data := make([]byte, base64.StdEncoding.DecodedLen(len(s.PublicKey)))
|
||||||
|
if _, err := base64.StdEncoding.Decode(data, s.PublicKey); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error decoding publicKey")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trim padding from end of key.
|
||||||
|
data = bytes.TrimRight(data, "\x00")
|
||||||
|
publicKey, err := ssh.ParsePublicKey(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error parsing publicKey")
|
||||||
|
}
|
||||||
|
|
||||||
|
return publicKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||||
|
// (ott) from the body and creates a new SSH certificate with the information in
|
||||||
|
// the request.
|
||||||
|
func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var body SignSSHRequest
|
||||||
|
if err := ReadJSON(r.Body, &body); err != nil {
|
||||||
|
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logOtt(w, body.OTT)
|
||||||
|
if err := body.Validate(); err != nil {
|
||||||
|
WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey, err := body.ParsePublicKey()
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, BadRequest(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := provisioner.SSHOptions{
|
||||||
|
CertType: body.CertType,
|
||||||
|
Principals: body.Principals,
|
||||||
|
ValidBefore: body.ValidBefore,
|
||||||
|
ValidAfter: body.ValidAfter,
|
||||||
|
}
|
||||||
|
|
||||||
|
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Unauthorized(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Forbidden(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
// logCertificate(w, cert)
|
||||||
|
JSON(w, &SignSSHResponse{
|
||||||
|
Certificate: SSHCertificate{cert},
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
@ -20,6 +21,8 @@ type Authority struct {
|
||||||
config *Config
|
config *Config
|
||||||
rootX509Certs []*x509.Certificate
|
rootX509Certs []*x509.Certificate
|
||||||
intermediateIdentity *x509util.Identity
|
intermediateIdentity *x509util.Identity
|
||||||
|
sshCAUserCertSignKey crypto.Signer
|
||||||
|
sshCAHostCertSignKey crypto.Signer
|
||||||
validateOnce bool
|
validateOnce bool
|
||||||
certificates *sync.Map
|
certificates *sync.Map
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
|
@ -117,6 +120,9 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
a.sshCAHostCertSignKey = a.intermediateIdentity.Key.(crypto.Signer)
|
||||||
|
a.sshCAUserCertSignKey = a.intermediateIdentity.Key.(crypto.Signer)
|
||||||
|
|
||||||
// Store all the provisioners
|
// Store all the provisioners
|
||||||
for _, p := range a.config.AuthorityConfig.Provisioners {
|
for _, p := range a.config.AuthorityConfig.Provisioners {
|
||||||
if err := a.provisioners.Store(p); err != nil {
|
if err := a.provisioners.Store(p); err != nil {
|
||||||
|
|
|
@ -29,10 +29,16 @@ var (
|
||||||
}
|
}
|
||||||
defaultDisableRenewal = false
|
defaultDisableRenewal = false
|
||||||
globalProvisionerClaims = provisioner.Claims{
|
globalProvisionerClaims = provisioner.Claims{
|
||||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
DisableRenewal: &defaultDisableRenewal,
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||||
|
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
|
DefaultUserSSHDur: &provisioner.Duration{Duration: 4 * time.Hour},
|
||||||
|
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||||
|
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||||
|
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -8,10 +8,18 @@ import (
|
||||||
|
|
||||||
// Claims so that individual provisioners can override global claims.
|
// Claims so that individual provisioners can override global claims.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
|
// TLS CA properties
|
||||||
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||||
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||||
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||||
|
// SSH CA properties
|
||||||
|
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
|
||||||
|
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
|
||||||
|
DefaultUserSSHDur *Duration `json:"defaultUserSSHCertDuration,omitempty"`
|
||||||
|
MinHostSSHDur *Duration `json:"minHostSSHCertDuration,omitempty"`
|
||||||
|
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
|
||||||
|
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claimer is the type that controls claims. It provides an interface around the
|
// Claimer is the type that controls claims. It provides an interface around the
|
||||||
|
@ -35,6 +43,12 @@ func (c *Claimer) Claims() Claims {
|
||||||
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
||||||
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
||||||
DisableRenewal: &disableRenewal,
|
DisableRenewal: &disableRenewal,
|
||||||
|
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
|
||||||
|
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
|
||||||
|
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
|
||||||
|
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
|
||||||
|
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
|
||||||
|
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +92,66 @@ func (c *Claimer) IsDisableRenewal() bool {
|
||||||
return *c.claims.DisableRenewal
|
return *c.claims.DisableRenewal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultUserSSHCertDuration returns the default SSH user cert duration for the
|
||||||
|
// provisioner. If the default is not set within the provisioner, then the
|
||||||
|
// global default from the authority configuration will be used.
|
||||||
|
func (c *Claimer) DefaultUserSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.DefaultUserSSHDur == nil {
|
||||||
|
return c.global.DefaultUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.DefaultUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinUserSSHCertDuration returns the minimum SSH user cert duration for the
|
||||||
|
// provisioner. If the minimum is not set within the provisioner, then the
|
||||||
|
// global minimum from the authority configuration will be used.
|
||||||
|
func (c *Claimer) MinUserSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.MinUserSSHDur == nil {
|
||||||
|
return c.global.MinUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.MinUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxUserSSHCertDuration returns the maximum SSH user cert duration for the
|
||||||
|
// provisioner. If the maximum is not set within the provisioner, then the
|
||||||
|
// global maximum from the authority configuration will be used.
|
||||||
|
func (c *Claimer) MaxUserSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.MaxUserSSHDur == nil {
|
||||||
|
return c.global.MaxUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.MaxUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultHostSSHCertDuration returns the default SSH host cert duration for the
|
||||||
|
// provisioner. If the default is not set within the provisioner, then the
|
||||||
|
// global default from the authority configuration will be used.
|
||||||
|
func (c *Claimer) DefaultHostSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.DefaultHostSSHDur == nil {
|
||||||
|
return c.global.DefaultHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.DefaultHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinHostSSHCertDuration returns the minimum SSH host cert duration for the
|
||||||
|
// provisioner. If the minimum is not set within the provisioner, then the
|
||||||
|
// global minimum from the authority configuration will be used.
|
||||||
|
func (c *Claimer) MinHostSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.MinHostSSHDur == nil {
|
||||||
|
return c.global.MinHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.MinHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxHostSSHCertDuration returns the maximum SSH Host cert duration for the
|
||||||
|
// provisioner. If the maximum is not set within the provisioner, then the
|
||||||
|
// global maximum from the authority configuration will be used.
|
||||||
|
func (c *Claimer) MaxHostSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.MaxHostSSHDur == nil {
|
||||||
|
return c.global.MaxHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.MaxHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
// Validate validates and modifies the Claims with default values.
|
// Validate validates and modifies the Claims with default values.
|
||||||
func (c *Claimer) Validate() error {
|
func (c *Claimer) Validate() error {
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -13,6 +13,11 @@ import (
|
||||||
type jwtPayload struct {
|
type jwtPayload struct {
|
||||||
jose.Claims
|
jose.Claims
|
||||||
SANs []string `json:"sans,omitempty"`
|
SANs []string `json:"sans,omitempty"`
|
||||||
|
Step *stepPayload `json:"step,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type stepPayload struct {
|
||||||
|
SSH *SSHOptions `json:"ssh,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWK is the default provisioner, an entity that can sign tokens necessary for
|
// JWK is the default provisioner, an entity that can sign tokens necessary for
|
||||||
|
@ -134,6 +139,12 @@ func (p *JWK) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for SSH token
|
||||||
|
if claims.Step != nil && claims.Step.SSH != nil {
|
||||||
|
return p.authorizeSSHSign(claims)
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: This is for backwards compatibility with older versions of cli
|
// NOTE: This is for backwards compatibility with older versions of cli
|
||||||
// and certificates. Older versions added the token subject as the only SAN
|
// and certificates. Older versions added the token subject as the only SAN
|
||||||
// in a CSR by default.
|
// in a CSR by default.
|
||||||
|
@ -159,3 +170,37 @@ func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *JWK) authorizeSSHSign(claims *jwtPayload) ([]SignOption, error) {
|
||||||
|
t := now()
|
||||||
|
opts := claims.Step.SSH
|
||||||
|
signOptions := []SignOption{
|
||||||
|
// validates user's SSHOptions with the ones in the token
|
||||||
|
&sshCertificateOptionsValidator{opts},
|
||||||
|
// set the default extensions
|
||||||
|
&sshDefaultExtensionModifier{},
|
||||||
|
// set the key id to the token subject
|
||||||
|
sshCertificateKeyIDModifier(claims.Subject),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add modifiers from custom claims
|
||||||
|
if opts.CertType != "" {
|
||||||
|
signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType))
|
||||||
|
}
|
||||||
|
if len(opts.Principals) > 0 {
|
||||||
|
signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals))
|
||||||
|
}
|
||||||
|
if !opts.ValidAfter.IsZero() {
|
||||||
|
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
|
||||||
|
}
|
||||||
|
if !opts.ValidBefore.IsZero() {
|
||||||
|
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(signOptions,
|
||||||
|
// checks the validity bounds, and set the validity if has not been set
|
||||||
|
&sshCertificateValidityModifier{p.claimer},
|
||||||
|
// require all the fields in the SSH certificate
|
||||||
|
&sshCertificateDefaultValidator{},
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
290
authority/provisioner/sign_ssh_options.go
Normal file
290
authority/provisioner/sign_ssh_options.go
Normal file
|
@ -0,0 +1,290 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SSHUserCert is the string used to represent ssh.UserCert.
|
||||||
|
SSHUserCert = "user"
|
||||||
|
|
||||||
|
// SSHHostCert is the string used to represent ssh.HostCert.
|
||||||
|
SSHHostCert = "host"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHCertificateModifier is the interface used to change properties in an SSH
|
||||||
|
// certificate.
|
||||||
|
type SSHCertificateModifier interface {
|
||||||
|
SignOption
|
||||||
|
Modify(cert *ssh.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateOptionModifier is the interface used to add custom options used
|
||||||
|
// to modify the SSH certificate.
|
||||||
|
type SSHCertificateOptionModifier interface {
|
||||||
|
SignOption
|
||||||
|
Option(o SSHOptions) SSHCertificateModifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateValidator is the interface used to validate an SSH certificate.
|
||||||
|
type SSHCertificateValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(crt *ssh.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateOptionsValidator is the interface used to validate the custom
|
||||||
|
// options used to modify the SSH certificate.
|
||||||
|
type SSHCertificateOptionsValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(got SSHOptions) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHOptions contains the options that can be passed to the SignSSH method.
|
||||||
|
type SSHOptions struct {
|
||||||
|
CertType string `json:"certType"`
|
||||||
|
Principals []string `json:"principals"`
|
||||||
|
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||||
|
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the uint32 representation of the CertType.
|
||||||
|
func (o SSHOptions) Type() uint32 {
|
||||||
|
return sshCertTypeUInt32(o.CertType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
|
||||||
|
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
|
||||||
|
switch o.CertType {
|
||||||
|
case "": // ignore
|
||||||
|
case SSHUserCert:
|
||||||
|
cert.CertType = ssh.UserCert
|
||||||
|
case SSHHostCert:
|
||||||
|
cert.CertType = ssh.HostCert
|
||||||
|
default:
|
||||||
|
return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType)
|
||||||
|
}
|
||||||
|
cert.ValidPrincipals = o.Principals
|
||||||
|
if !o.ValidAfter.IsZero() {
|
||||||
|
cert.ValidAfter = uint64(o.ValidAfter.Time().Unix())
|
||||||
|
}
|
||||||
|
if !o.ValidBefore.IsZero() {
|
||||||
|
cert.ValidBefore = uint64(o.ValidBefore.Time().Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
|
||||||
|
return errors.New("ssh certificate valid after cannot be greater than valid before")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// match compares two SSHOptions and return an error if they don't match. It
|
||||||
|
// ignores zero values.
|
||||||
|
func (o SSHOptions) match(got SSHOptions) error {
|
||||||
|
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
|
||||||
|
return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
|
||||||
|
}
|
||||||
|
if len(o.Principals) > 0 && len(got.Principals) > 0 && !equalStringSlice(o.Principals, got.Principals) {
|
||||||
|
return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
|
||||||
|
}
|
||||||
|
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
|
||||||
|
return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
|
||||||
|
}
|
||||||
|
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
|
||||||
|
return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given
|
||||||
|
// Key ID in the SSH certificate.
|
||||||
|
type sshCertificateKeyIDModifier string
|
||||||
|
|
||||||
|
func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.KeyId = string(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the
|
||||||
|
// certificate type to the SSH certificate.
|
||||||
|
type sshCertificateCertTypeModifier string
|
||||||
|
|
||||||
|
func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.CertType = sshCertTypeUInt32(string(m))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the
|
||||||
|
// principals to the SSH certificate.
|
||||||
|
type sshCertificatePrincipalsModifier []string
|
||||||
|
|
||||||
|
func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidPrincipals = []string(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the
|
||||||
|
// ValidAfter in the SSH certificate.
|
||||||
|
type sshCertificateValidAfterModifier uint64
|
||||||
|
|
||||||
|
func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidAfter = uint64(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the
|
||||||
|
// ValidBefore in the SSH certificate.
|
||||||
|
type sshCertificateValidBeforeModifier uint64
|
||||||
|
|
||||||
|
func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidBefore = uint64(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
|
||||||
|
// the default extensions in an SSH certificate.
|
||||||
|
type sshDefaultExtensionModifier struct{}
|
||||||
|
|
||||||
|
func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
if cert.Extensions == nil {
|
||||||
|
cert.Extensions = make(map[string]string)
|
||||||
|
}
|
||||||
|
cert.Extensions["permit-X11-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-agent-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-port-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-pty"] = ""
|
||||||
|
cert.Extensions["permit-user-rc"] = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidityModifier is a SSHCertificateModifier checks the
|
||||||
|
// validity bounds, setting them if they are not provided. It will fail if a
|
||||||
|
// CertType has not been set or is not valid.
|
||||||
|
type sshCertificateValidityModifier struct {
|
||||||
|
*Claimer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
var d, min, max time.Duration
|
||||||
|
switch cert.CertType {
|
||||||
|
case ssh.UserCert:
|
||||||
|
d = m.DefaultUserSSHCertDuration()
|
||||||
|
min = m.MinUserSSHCertDuration()
|
||||||
|
max = m.MaxUserSSHCertDuration()
|
||||||
|
case ssh.HostCert:
|
||||||
|
d = m.DefaultHostSSHCertDuration()
|
||||||
|
min = m.MinHostSSHCertDuration()
|
||||||
|
max = m.MaxHostSSHCertDuration()
|
||||||
|
case 0:
|
||||||
|
return errors.New("ssh certificate type has not been set")
|
||||||
|
default:
|
||||||
|
return errors.Errorf("unknown ssh certificate type %d", cert.CertType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cert.ValidAfter == 0 {
|
||||||
|
cert.ValidAfter = uint64(now().Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidBefore == 0 {
|
||||||
|
t := time.Unix(int64(cert.ValidAfter), 0)
|
||||||
|
cert.ValidBefore = uint64(t.Add(d).Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
diff := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
|
||||||
|
switch {
|
||||||
|
case diff < max:
|
||||||
|
return errors.Errorf("ssh certificate duration cannot be lower than %s", min)
|
||||||
|
case diff > max:
|
||||||
|
return errors.Errorf("ssh certificate duration cannot be greater than %s", max)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
|
||||||
|
// usually present in the token.
|
||||||
|
type sshCertificateOptionsValidator struct {
|
||||||
|
*SSHOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (want *sshCertificateOptionsValidator) Valid(got SSHOptions) error {
|
||||||
|
return want.match(got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateDefaultValidator implements a simple validator for all the
|
||||||
|
// fields in the SSH certificate.
|
||||||
|
type sshCertificateDefaultValidator struct{}
|
||||||
|
|
||||||
|
// Valid returns an error if the given certificate does not contain the necessary fields.
|
||||||
|
func (v *sshCertificateDefaultValidator) Valid(crt *ssh.Certificate) error {
|
||||||
|
switch {
|
||||||
|
case len(crt.Nonce) == 0:
|
||||||
|
return errors.New("ssh certificate nonce cannot be empty")
|
||||||
|
case crt.Key == nil:
|
||||||
|
return errors.New("ssh certificate key cannot be nil")
|
||||||
|
case crt.Serial == 0:
|
||||||
|
return errors.New("ssh certificate serial cannot be 0")
|
||||||
|
case crt.CertType != ssh.UserCert && crt.CertType != ssh.HostCert:
|
||||||
|
return errors.Errorf("ssh certificate has an unknown type: %d", crt.CertType)
|
||||||
|
case crt.KeyId == "":
|
||||||
|
return errors.New("ssh certificate key id cannot be empty")
|
||||||
|
case len(crt.ValidPrincipals) == 0:
|
||||||
|
return errors.New("ssh certificate valid principals cannot be empty")
|
||||||
|
case crt.ValidAfter == 0:
|
||||||
|
return errors.New("ssh certificate valid after cannot be 0")
|
||||||
|
case crt.ValidBefore == 0:
|
||||||
|
return errors.New("ssh certificate valid before cannot be 0")
|
||||||
|
case len(crt.Extensions) == 0:
|
||||||
|
return errors.New("ssh certificate extensions cannot be empty")
|
||||||
|
case crt.SignatureKey == nil:
|
||||||
|
return errors.New("ssh certificate signature key cannot be nil")
|
||||||
|
case crt.Signature == nil:
|
||||||
|
return errors.New("ssh certificate signature cannot be nil")
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertTypeName returns the string representation of the given ssh.CertType.
|
||||||
|
func sshCertTypeString(ct uint32) string {
|
||||||
|
switch ct {
|
||||||
|
case 0:
|
||||||
|
return ""
|
||||||
|
case ssh.UserCert:
|
||||||
|
return SSHUserCert
|
||||||
|
case ssh.HostCert:
|
||||||
|
return SSHHostCert
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown (%d)", ct)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertTypeUInt32
|
||||||
|
func sshCertTypeUInt32(ct string) uint32 {
|
||||||
|
switch ct {
|
||||||
|
case SSHUserCert:
|
||||||
|
return ssh.UserCert
|
||||||
|
case SSHHostCert:
|
||||||
|
return ssh.HostCert
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalStringSlice(a, b []string) bool {
|
||||||
|
var l int
|
||||||
|
if l = len(a); l != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
visit := make(map[string]struct{}, l)
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
visit[a[i]] = struct{}{}
|
||||||
|
}
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
if _, ok := visit[b[i]]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
|
@ -57,6 +57,17 @@ func (t *TimeDuration) SetTime(tt time.Time) {
|
||||||
t.t, t.d = tt, 0
|
t.t, t.d = tt, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsZero returns true the TimeDuration represents the zero value, false
|
||||||
|
// otherwise.
|
||||||
|
func (t *TimeDuration) IsZero() bool {
|
||||||
|
return t.t.IsZero() && t.d == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equal returns if t and other are equal.
|
||||||
|
func (t *TimeDuration) Equal(other *TimeDuration) bool {
|
||||||
|
return t.t.Equal(other.t) && t.d == other.d
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalJSON implements the json.Marshaler interface. If the time is set it
|
// MarshalJSON implements the json.Marshaler interface. If the time is set it
|
||||||
// will return the time in RFC 3339 format if not it will return the duration
|
// will return the time in RFC 3339 format if not it will return the duration
|
||||||
// string.
|
// string.
|
||||||
|
|
114
authority/ssh.go
Normal file
114
authority/ssh.go
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
package authority
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateSSHPublicKeyID(key ssh.PublicKey) string {
|
||||||
|
sum := sha256.Sum256(key.Marshal())
|
||||||
|
return strings.ToLower(hex.EncodeToString(sum[:]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSH creates a signed SSH certificate with the given public key and options.
|
||||||
|
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
|
var mods []provisioner.SSHCertificateModifier
|
||||||
|
var validators []provisioner.SSHCertificateValidator
|
||||||
|
|
||||||
|
for _, op := range signOpts {
|
||||||
|
switch o := op.(type) {
|
||||||
|
// modify the ssh.Certificate
|
||||||
|
case provisioner.SSHCertificateModifier:
|
||||||
|
mods = append(mods, o)
|
||||||
|
// modify the ssh.Certificate given the SSHOptions
|
||||||
|
case provisioner.SSHCertificateOptionModifier:
|
||||||
|
mods = append(mods, o.Option(opts))
|
||||||
|
// validate the ssh.Certificate
|
||||||
|
case provisioner.SSHCertificateValidator:
|
||||||
|
validators = append(validators, o)
|
||||||
|
// validate the given SSHOptions
|
||||||
|
case provisioner.SSHCertificateOptionsValidator:
|
||||||
|
if err := o.Valid(opts); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusUnauthorized}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Errorf("signSSH: invalid extra option type %T", o),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, err := randutil.ASCII(32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var serial uint64
|
||||||
|
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error reading random number")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build base certificate with the key and some random values
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte(nonce),
|
||||||
|
Key: key,
|
||||||
|
Serial: serial,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use opts to modify the certificate
|
||||||
|
if err := opts.Modify(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use provisioner modifiers
|
||||||
|
for _, m := range mods {
|
||||||
|
if err := m.Modify(cert); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get signer from authority keys
|
||||||
|
var signer ssh.Signer
|
||||||
|
switch cert.CertType {
|
||||||
|
case ssh.UserCert:
|
||||||
|
signer, err = ssh.NewSignerFromSigner(a.sshCAUserCertSignKey)
|
||||||
|
case ssh.HostCert:
|
||||||
|
signer, err = ssh.NewSignerFromSigner(a.sshCAHostCertSignKey)
|
||||||
|
default:
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Errorf("unexpected ssh certificate type: %d", cert.CertType),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cert.SignatureKey = signer.PublicKey()
|
||||||
|
|
||||||
|
// Get bytes for signing trailing the signature length.
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
// Sign the certificate
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
|
||||||
|
// User provisioners validators
|
||||||
|
for _, v := range validators {
|
||||||
|
if err := v.Valid(cert); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusUnauthorized}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
22
ca/client.go
22
ca/client.go
|
@ -373,6 +373,28 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
|
||||||
return &sign, nil
|
return &sign, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignSSH performs the SSH certificate sign request to the CA and returns the
|
||||||
|
// api.SignSSHResponse struct.
|
||||||
|
func (c *Client) SignSSH(req *api.SignSSHRequest) (*api.SignSSHResponse, error) {
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
|
}
|
||||||
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign-ssh"})
|
||||||
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readError(resp.Body)
|
||||||
|
}
|
||||||
|
var sign api.SignSSHResponse
|
||||||
|
if err := readJSON(resp.Body, &sign); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||||
|
}
|
||||||
|
return &sign, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Renew performs the renew request to the CA and returns the api.SignResponse
|
// Renew performs the renew request to the CA and returns the api.SignResponse
|
||||||
// struct.
|
// struct.
|
||||||
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
||||||
|
|
Loading…
Reference in a new issue