forked from TrueCloudLab/certificates
commit
7726f5ec75
41 changed files with 2710 additions and 120 deletions
8
Gopkg.lock
generated
8
Gopkg.lock
generated
|
@ -262,15 +262,20 @@
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
digest = "1:5dd7da6df07f42194cb25d162b4b89664ed7b08d7d4334f6a288393d54b095ce"
|
digest = "1:afc49fe39c8c591fc2c8ddc73adc4c69e67125dde6c58e24c91b3b0cf78602be"
|
||||||
name = "golang.org/x/crypto"
|
name = "golang.org/x/crypto"
|
||||||
packages = [
|
packages = [
|
||||||
"cryptobyte",
|
"cryptobyte",
|
||||||
"cryptobyte/asn1",
|
"cryptobyte/asn1",
|
||||||
|
"curve25519",
|
||||||
"ed25519",
|
"ed25519",
|
||||||
"ed25519/internal/edwards25519",
|
"ed25519/internal/edwards25519",
|
||||||
|
"internal/chacha20",
|
||||||
|
"internal/subtle",
|
||||||
"ocsp",
|
"ocsp",
|
||||||
"pbkdf2",
|
"pbkdf2",
|
||||||
|
"poly1305",
|
||||||
|
"ssh",
|
||||||
"ssh/terminal",
|
"ssh/terminal",
|
||||||
]
|
]
|
||||||
pruneopts = "UT"
|
pruneopts = "UT"
|
||||||
|
@ -394,6 +399,7 @@
|
||||||
"github.com/urfave/cli",
|
"github.com/urfave/cli",
|
||||||
"golang.org/x/crypto/ed25519",
|
"golang.org/x/crypto/ed25519",
|
||||||
"golang.org/x/crypto/ocsp",
|
"golang.org/x/crypto/ocsp",
|
||||||
|
"golang.org/x/crypto/ssh",
|
||||||
"golang.org/x/net/http2",
|
"golang.org/x/net/http2",
|
||||||
"gopkg.in/square/go-jose.v2",
|
"gopkg.in/square/go-jose.v2",
|
||||||
"gopkg.in/square/go-jose.v2/jwt",
|
"gopkg.in/square/go-jose.v2/jwt",
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/dsa"
|
"crypto/dsa"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
@ -26,9 +27,10 @@ import (
|
||||||
|
|
||||||
// Authority is the interface implemented by a CA authority.
|
// Authority is the interface implemented by a CA authority.
|
||||||
type Authority interface {
|
type Authority interface {
|
||||||
|
SSHAuthority
|
||||||
// NOTE: Authorize will be deprecated in future releases. Please use the
|
// NOTE: Authorize will be deprecated in future releases. Please use the
|
||||||
// context specific Authoirize[Sign|Revoke|etc.] methods.
|
// context specific Authorize[Sign|Revoke|etc.] methods.
|
||||||
Authorize(ott string) ([]provisioner.SignOption, error)
|
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||||
GetTLSOptions() *tlsutil.TLSOptions
|
GetTLSOptions() *tlsutil.TLSOptions
|
||||||
Root(shasum string) (*x509.Certificate, error)
|
Root(shasum string) (*x509.Certificate, error)
|
||||||
|
@ -249,6 +251,8 @@ func (h *caHandler) Route(r Router) {
|
||||||
r.MethodFunc("GET", "/federation", h.Federation)
|
r.MethodFunc("GET", "/federation", h.Federation)
|
||||||
// For compatibility with old code:
|
// For compatibility with old code:
|
||||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||||
|
// SSH CA
|
||||||
|
r.MethodFunc("POST", "/sign-ssh", h.SignSSH)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health is an HTTP handler that returns the status of the server.
|
// Health is an HTTP handler that returns the status of the server.
|
||||||
|
|
|
@ -31,6 +31,7 @@ import (
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"github.com/smallstep/cli/crypto/tlsutil"
|
"github.com/smallstep/cli/crypto/tlsutil"
|
||||||
"github.com/smallstep/cli/jose"
|
"github.com/smallstep/cli/jose"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -424,7 +425,7 @@ type mockProvisioner struct {
|
||||||
getEncryptedKey func() (string, string, bool)
|
getEncryptedKey func() (string, string, bool)
|
||||||
init func(provisioner.Config) error
|
init func(provisioner.Config) error
|
||||||
authorizeRevoke func(ott string) error
|
authorizeRevoke func(ott string) error
|
||||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
authorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||||
authorizeRenewal func(*x509.Certificate) error
|
authorizeRenewal func(*x509.Certificate) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -480,9 +481,9 @@ func (m *mockProvisioner) AuthorizeRevoke(ott string) error {
|
||||||
return m.err
|
return m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockProvisioner) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
func (m *mockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
if m.authorizeSign != nil {
|
if m.authorizeSign != nil {
|
||||||
return m.authorizeSign(ott)
|
return m.authorizeSign(ctx, ott)
|
||||||
}
|
}
|
||||||
return m.ret1.([]provisioner.SignOption), m.err
|
return m.ret1.([]provisioner.SignOption), m.err
|
||||||
}
|
}
|
||||||
|
@ -501,6 +502,8 @@ type mockAuthority struct {
|
||||||
getTLSOptions func() *tlsutil.TLSOptions
|
getTLSOptions func() *tlsutil.TLSOptions
|
||||||
root func(shasum string) (*x509.Certificate, error)
|
root func(shasum string) (*x509.Certificate, error)
|
||||||
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||||
|
signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||||
|
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||||
|
@ -511,7 +514,7 @@ type mockAuthority struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove once Authorize is deprecated.
|
// TODO: remove once Authorize is deprecated.
|
||||||
func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) {
|
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
return m.AuthorizeSign(ott)
|
return m.AuthorizeSign(ott)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -543,6 +546,20 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio
|
||||||
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
|
if m.signSSH != nil {
|
||||||
|
return m.signSSH(key, opts, signOpts...)
|
||||||
|
}
|
||||||
|
return m.ret1.(*ssh.Certificate), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
|
if m.signSSHAddUser != nil {
|
||||||
|
return m.signSSHAddUser(key, cert)
|
||||||
|
}
|
||||||
|
return m.ret1.(*ssh.Certificate), m.err
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) {
|
func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) {
|
||||||
if m.renew != nil {
|
if m.renew != nil {
|
||||||
return m.renew(cert)
|
return m.renew(cert)
|
||||||
|
|
159
api/ssh.go
Normal file
159
api/ssh.go
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHAuthority is the interface implemented by a SSH CA authority.
|
||||||
|
type SSHAuthority interface {
|
||||||
|
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||||
|
SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSHRequest is the request body of an SSH certificate request.
|
||||||
|
type SignSSHRequest struct {
|
||||||
|
PublicKey []byte `json:"publicKey"` //base64 encoded
|
||||||
|
OTT string `json:"ott"`
|
||||||
|
CertType string `json:"certType,omitempty"`
|
||||||
|
Principals []string `json:"principals,omitempty"`
|
||||||
|
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||||
|
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||||
|
AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSHResponse is the response object that returns the SSH certificate.
|
||||||
|
type SignSSHResponse struct {
|
||||||
|
Certificate SSHCertificate `json:"crt"`
|
||||||
|
AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificate represents the response SSH certificate.
|
||||||
|
type SSHCertificate struct {
|
||||||
|
*ssh.Certificate `json:"omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
|
||||||
|
// base64 encoded, openssh wire format version of the certificate.
|
||||||
|
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
|
||||||
|
if c.Certificate == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
s := base64.StdEncoding.EncodeToString(c.Certificate.Marshal())
|
||||||
|
return []byte(`"` + s + `"`), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the json.Unmarshaler interface. The certificate is
|
||||||
|
// expected to be a quoted, base64 encoded, openssh wire formatted block of bytes.
|
||||||
|
func (c *SSHCertificate) UnmarshalJSON(data []byte) error {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return errors.Wrap(err, "error decoding certificate")
|
||||||
|
}
|
||||||
|
if s == "" {
|
||||||
|
c.Certificate = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
certData, err := base64.StdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error decoding ssh certificate")
|
||||||
|
}
|
||||||
|
pub, err := ssh.ParsePublicKey(certData)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error parsing ssh certificate")
|
||||||
|
}
|
||||||
|
cert, ok := pub.(*ssh.Certificate)
|
||||||
|
if !ok {
|
||||||
|
return errors.Errorf("error decoding ssh certificate: %T is not an *ssh.Certificate", pub)
|
||||||
|
}
|
||||||
|
c.Certificate = cert
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the SignSSHRequest.
|
||||||
|
func (s *SignSSHRequest) Validate() error {
|
||||||
|
switch {
|
||||||
|
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
|
||||||
|
return errors.Errorf("unknown certType %s", s.CertType)
|
||||||
|
case len(s.PublicKey) == 0:
|
||||||
|
return errors.New("missing or empty publicKey")
|
||||||
|
case len(s.OTT) == 0:
|
||||||
|
return errors.New("missing or empty ott")
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||||
|
// (ott) from the body and creates a new SSH certificate with the information in
|
||||||
|
// the request.
|
||||||
|
func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var body SignSSHRequest
|
||||||
|
if err := ReadJSON(r.Body, &body); err != nil {
|
||||||
|
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logOtt(w, body.OTT)
|
||||||
|
if err := body.Validate(); err != nil {
|
||||||
|
WriteError(w, BadRequest(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var addUserPublicKey ssh.PublicKey
|
||||||
|
if body.AddUserPublicKey != nil {
|
||||||
|
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := provisioner.SSHOptions{
|
||||||
|
CertType: body.CertType,
|
||||||
|
Principals: body.Principals,
|
||||||
|
ValidBefore: body.ValidBefore,
|
||||||
|
ValidAfter: body.ValidAfter,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
|
||||||
|
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Unauthorized(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Forbidden(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var addUserCertificate *SSHCertificate
|
||||||
|
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
|
||||||
|
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Forbidden(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addUserCertificate = &SSHCertificate{addUserCert}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
JSON(w, &SignSSHResponse{
|
||||||
|
Certificate: SSHCertificate{cert},
|
||||||
|
AddUserCertificate: addUserCertificate,
|
||||||
|
})
|
||||||
|
}
|
327
api/ssh_test.go
Normal file
327
api/ssh_test.go
Normal file
|
@ -0,0 +1,327 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sshSignerKey = mustKey()
|
||||||
|
sshUserKey = mustKey()
|
||||||
|
sshHostKey = mustKey()
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustKey() *ecdsa.PrivateKey {
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return priv
|
||||||
|
}
|
||||||
|
|
||||||
|
func signSSHCertificate(cert *ssh.Certificate) error {
|
||||||
|
signerKey, err := ssh.NewPublicKey(sshSignerKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
signer, err := ssh.NewSignerFromSigner(sshSignerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.SignatureKey = signerKey
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSignedUserCertificate() (*ssh.Certificate, error) {
|
||||||
|
key, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := time.Now()
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte("1234567890"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 1234567890,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: "user@localhost",
|
||||||
|
ValidPrincipals: []string{"user"},
|
||||||
|
ValidAfter: uint64(t.Unix()),
|
||||||
|
ValidBefore: uint64(t.Add(time.Hour).Unix()),
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: map[string]string{},
|
||||||
|
Extensions: map[string]string{
|
||||||
|
"permit-X11-forwarding": "",
|
||||||
|
"permit-agent-forwarding": "",
|
||||||
|
"permit-port-forwarding": "",
|
||||||
|
"permit-pty": "",
|
||||||
|
"permit-user-rc": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Reserved: []byte{},
|
||||||
|
}
|
||||||
|
if err := signSSHCertificate(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSignedHostCertificate() (*ssh.Certificate, error) {
|
||||||
|
key, err := ssh.NewPublicKey(sshHostKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := time.Now()
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte("1234567890"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 1234567890,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: "internal.smallstep.com",
|
||||||
|
ValidPrincipals: []string{"internal.smallstep.com"},
|
||||||
|
ValidAfter: uint64(t.Unix()),
|
||||||
|
ValidBefore: uint64(t.Add(time.Hour).Unix()),
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: map[string]string{},
|
||||||
|
Extensions: map[string]string{},
|
||||||
|
},
|
||||||
|
Reserved: []byte{},
|
||||||
|
}
|
||||||
|
if err := signSSHCertificate(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCertificate_MarshalJSON(t *testing.T) {
|
||||||
|
user, err := getSignedUserCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
host, err := getSignedHostCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
Certificate *ssh.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
||||||
|
{"user", fields{Certificate: user}, []byte(`"` + userB64 + `"`), false},
|
||||||
|
{"user", fields{Certificate: host}, []byte(`"` + hostB64 + `"`), false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := SSHCertificate{
|
||||||
|
Certificate: tt.fields.Certificate,
|
||||||
|
}
|
||||||
|
got, err := c.MarshalJSON()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SSHCertificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SSHCertificate.MarshalJSON() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
|
||||||
|
user, err := getSignedUserCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
host, err := getSignedHostCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||||
|
keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal())
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *ssh.Certificate
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"null", args{[]byte(`null`)}, nil, false},
|
||||||
|
{"empty", args{[]byte(`""`)}, nil, false},
|
||||||
|
{"user", args{[]byte(`"` + userB64 + `"`)}, user, false},
|
||||||
|
{"host", args{[]byte(`"` + hostB64 + `"`)}, host, false},
|
||||||
|
{"bad-string", args{[]byte(userB64)}, nil, true},
|
||||||
|
{"bad-base64", args{[]byte(`"this-is-not-base64"`)}, nil, true},
|
||||||
|
{"bad-key", args{[]byte(`"bm90LWEta2V5"`)}, nil, true},
|
||||||
|
{"bat-cert", args{[]byte(`"` + keyB64 + `"`)}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &SSHCertificate{}
|
||||||
|
if err := c.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SSHCertificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(tt.want, c.Certificate) {
|
||||||
|
t.Errorf("SSHCertificate.UnmarshalJSON() got = %v, want %v\n", c.Certificate, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignSSHRequest_Validate(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
PublicKey []byte
|
||||||
|
OTT string
|
||||||
|
CertType string
|
||||||
|
Principals []string
|
||||||
|
ValidAfter TimeDuration
|
||||||
|
ValidBefore TimeDuration
|
||||||
|
AddUserPublicKey []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||||
|
{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||||
|
{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||||
|
{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||||
|
{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||||
|
{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||||
|
{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s := &SignSSHRequest{
|
||||||
|
PublicKey: tt.fields.PublicKey,
|
||||||
|
OTT: tt.fields.OTT,
|
||||||
|
CertType: tt.fields.CertType,
|
||||||
|
Principals: tt.fields.Principals,
|
||||||
|
ValidAfter: tt.fields.ValidAfter,
|
||||||
|
ValidBefore: tt.fields.ValidBefore,
|
||||||
|
AddUserPublicKey: tt.fields.AddUserPublicKey,
|
||||||
|
}
|
||||||
|
if err := s.Validate(); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_caHandler_SignSSH(t *testing.T) {
|
||||||
|
user, err := getSignedUserCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
host, err := getSignedHostCertificate()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||||
|
|
||||||
|
userReq, err := json.Marshal(SignSSHRequest{
|
||||||
|
PublicKey: user.Key.Marshal(),
|
||||||
|
OTT: "ott",
|
||||||
|
})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
hostReq, err := json.Marshal(SignSSHRequest{
|
||||||
|
PublicKey: host.Key.Marshal(),
|
||||||
|
OTT: "ott",
|
||||||
|
})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
userAddReq, err := json.Marshal(SignSSHRequest{
|
||||||
|
PublicKey: user.Key.Marshal(),
|
||||||
|
OTT: "ott",
|
||||||
|
AddUserPublicKey: user.Key.Marshal(),
|
||||||
|
})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
Authority Authority
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
w http.ResponseWriter
|
||||||
|
r *http.Request
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req []byte
|
||||||
|
authErr error
|
||||||
|
signCert *ssh.Certificate
|
||||||
|
signErr error
|
||||||
|
addUserCert *ssh.Certificate
|
||||||
|
addUserErr error
|
||||||
|
body []byte
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated},
|
||||||
|
{"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated},
|
||||||
|
{"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated},
|
||||||
|
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized},
|
||||||
|
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
|
||||||
|
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(&mockAuthority{
|
||||||
|
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||||
|
return []provisioner.SignOption{}, tt.authErr
|
||||||
|
},
|
||||||
|
signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
|
return tt.signCert, tt.signErr
|
||||||
|
},
|
||||||
|
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
|
return tt.addUserCert, tt.addUserErr
|
||||||
|
},
|
||||||
|
}).(*caHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "http://example.com/sign-ssh", bytes.NewReader(tt.req))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SignSSH(logging.NewResponseLogger(w), req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("caHandler.Root unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||||
|
t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,12 +1,14 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/cli/crypto/pemutil"
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
|
@ -20,6 +22,8 @@ type Authority struct {
|
||||||
config *Config
|
config *Config
|
||||||
rootX509Certs []*x509.Certificate
|
rootX509Certs []*x509.Certificate
|
||||||
intermediateIdentity *x509util.Identity
|
intermediateIdentity *x509util.Identity
|
||||||
|
sshCAUserCertSignKey crypto.Signer
|
||||||
|
sshCAHostCertSignKey crypto.Signer
|
||||||
validateOnce bool
|
validateOnce bool
|
||||||
certificates *sync.Map
|
certificates *sync.Map
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
|
@ -117,6 +121,22 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Decrypt and load SSH keys
|
||||||
|
if a.config.SSH != nil {
|
||||||
|
if a.config.SSH.HostKey != "" {
|
||||||
|
a.sshCAHostCertSignKey, err = parseCryptoSigner(a.config.SSH.HostKey, a.config.Password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if a.config.SSH.UserKey != "" {
|
||||||
|
a.sshCAUserCertSignKey, err = parseCryptoSigner(a.config.SSH.UserKey, a.config.Password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Store all the provisioners
|
// Store all the provisioners
|
||||||
for _, p := range a.config.AuthorityConfig.Provisioners {
|
for _, p := range a.config.AuthorityConfig.Provisioners {
|
||||||
if err := a.provisioners.Store(p); err != nil {
|
if err := a.provisioners.Store(p); err != nil {
|
||||||
|
@ -143,3 +163,19 @@ func (a *Authority) GetDatabase() db.AuthDB {
|
||||||
func (a *Authority) Shutdown() error {
|
func (a *Authority) Shutdown() error {
|
||||||
return a.db.Shutdown()
|
return a.db.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseCryptoSigner(filename, password string) (crypto.Signer, error) {
|
||||||
|
var opts []pemutil.Options
|
||||||
|
if password != "" {
|
||||||
|
opts = append(opts, pemutil.WithPassword([]byte(password)))
|
||||||
|
}
|
||||||
|
key, err := pemutil.Read(filename, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
signer, ok := key.(crypto.Signer)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.Errorf("key %s of type %T cannot be used for signing operations", filename, key)
|
||||||
|
}
|
||||||
|
return signer, nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -72,33 +73,51 @@ func (a *Authority) authorizeToken(ott string) (provisioner.Interface, error) {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authorize is a passthrough to AuthorizeSign.
|
// Authorize grabs the method from the context and authorizes a signature
|
||||||
// NOTE: Authorize will be deprecated in a future release. Please use the
|
// request by validating the one-time-token.
|
||||||
// context specific Authorize[Sign|Revoke|etc.] going forwards.
|
func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
|
var errContext = apiCtx{"ott": ott}
|
||||||
return a.AuthorizeSign(ott)
|
switch m := provisioner.MethodFromContext(ctx); m {
|
||||||
|
case provisioner.SignMethod:
|
||||||
|
return a.authorizeSign(ctx, ott)
|
||||||
|
case provisioner.SignSSHMethod:
|
||||||
|
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
|
||||||
|
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
|
||||||
|
}
|
||||||
|
return a.authorizeSign(ctx, ott)
|
||||||
|
case provisioner.RevokeMethod:
|
||||||
|
return nil, &apiError{errors.New("authorize: revoke method is not supported"), http.StatusInternalServerError, errContext}
|
||||||
|
default:
|
||||||
|
return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign authorizes a signature request by validating and authenticating
|
// authorizeSign loads the provisioner from the token, checks that it has not
|
||||||
// a OTT that must be sent w/ the request.
|
// been used again and calls the provisioner AuthorizeSign method. Returns a
|
||||||
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
// list of methods to apply to the signing flow.
|
||||||
var errContext = context{"ott": ott}
|
func (a *Authority) authorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
|
var errContext = apiCtx{"ott": ott}
|
||||||
p, err := a.authorizeToken(ott)
|
p, err := a.authorizeToken(ott)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
|
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
|
||||||
}
|
}
|
||||||
|
opts, err := p.AuthorizeSign(ctx, ott)
|
||||||
// Call the provisioner AuthorizeSign method to apply provisioner specific
|
|
||||||
// auth claims and get the signing options.
|
|
||||||
opts, err := p.AuthorizeSign(ott)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
|
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
|
||||||
}
|
}
|
||||||
|
|
||||||
return opts, nil
|
return opts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthorizeSign authorizes a signature request by validating and authenticating
|
||||||
|
// a OTT that must be sent w/ the request.
|
||||||
|
//
|
||||||
|
// NOTE: This method is deprecated and should not be used. We make it available
|
||||||
|
// in the short term os as not to break existing clients.
|
||||||
|
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||||
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||||
|
return a.Authorize(ctx, ott)
|
||||||
|
}
|
||||||
|
|
||||||
// authorizeRevoke authorizes a revocation request by validating and authenticating
|
// authorizeRevoke authorizes a revocation request by validating and authenticating
|
||||||
// the RevokeOptions POSTed with the request.
|
// the RevokeOptions POSTed with the request.
|
||||||
// Returns a tuple of the provisioner ID and error, if one occurred.
|
// Returns a tuple of the provisioner ID and error, if one occurred.
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/cli/crypto/pemutil"
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
|
@ -72,7 +75,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: "foo",
|
ott: "foo",
|
||||||
err: &apiError{errors.New("authorizeToken: error parsing token"),
|
err: &apiError{errors.New("authorizeToken: error parsing token"),
|
||||||
http.StatusUnauthorized, context{"ott": "foo"}},
|
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/prehistoric-token": func(t *testing.T) *authorizeTest {
|
"fail/prehistoric-token": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -91,7 +94,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
|
err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/provisioner-not-found": func(t *testing.T) *authorizeTest {
|
"fail/provisioner-not-found": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -113,7 +116,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"),
|
err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/simpledb": func(t *testing.T) *authorizeTest {
|
"ok/simpledb": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -150,7 +153,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
||||||
auth: _a,
|
auth: _a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorizeToken: token already used"),
|
err: &apiError{errors.New("authorizeToken: token already used"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/mockNoSQLDB": func(t *testing.T) *authorizeTest {
|
"ok/mockNoSQLDB": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -198,7 +201,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
||||||
auth: _a,
|
auth: _a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"),
|
err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"),
|
||||||
http.StatusInternalServerError, context{"ott": raw}},
|
http.StatusInternalServerError, apiCtx{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest {
|
"fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -223,7 +226,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
||||||
auth: _a,
|
auth: _a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorizeToken: token already used"),
|
err: &apiError{errors.New("authorizeToken: token already used"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -388,7 +391,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: "foo",
|
ott: "foo",
|
||||||
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
|
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
|
||||||
http.StatusUnauthorized, context{"ott": "foo"}},
|
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
|
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -406,7 +409,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
|
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) *authorizeTest {
|
"ok": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -480,7 +483,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: "foo",
|
ott: "foo",
|
||||||
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
|
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
|
||||||
http.StatusUnauthorized, context{"ott": "foo"}},
|
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
|
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -498,7 +501,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
|
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) *authorizeTest {
|
"ok": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -522,8 +525,8 @@ func TestAuthority_Authorize(t *testing.T) {
|
||||||
for name, genTestCase := range tests {
|
for name, genTestCase := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := genTestCase(t)
|
tc := genTestCase(t)
|
||||||
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||||
got, err := tc.auth.Authorize(tc.ott)
|
got, err := tc.auth.Authorize(ctx, tc.ott)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
|
@ -573,7 +576,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
crt: fooCrt,
|
crt: fooCrt,
|
||||||
err: &apiError{errors.New("renew: force"),
|
err: &apiError{errors.New("renew: force"),
|
||||||
http.StatusInternalServerError, context{"serialNumber": "102012593071130646873265215610956555026"}},
|
http.StatusInternalServerError, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/revoked": func(t *testing.T) *authorizeTest {
|
"fail/revoked": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -587,7 +590,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
crt: fooCrt,
|
crt: fooCrt,
|
||||||
err: &apiError{errors.New("renew: certificate has been revoked"),
|
err: &apiError{errors.New("renew: certificate has been revoked"),
|
||||||
http.StatusUnauthorized, context{"serialNumber": "102012593071130646873265215610956555026"}},
|
http.StatusUnauthorized, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/load-provisioner": func(t *testing.T) *authorizeTest {
|
"fail/load-provisioner": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -601,7 +604,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
crt: otherCrt,
|
crt: otherCrt,
|
||||||
err: &apiError{errors.New("renew: provisioner not found"),
|
err: &apiError{errors.New("renew: provisioner not found"),
|
||||||
http.StatusUnauthorized, context{"serialNumber": "41633491264736369593451462439668497527"}},
|
http.StatusUnauthorized, apiCtx{"serialNumber": "41633491264736369593451462439668497527"}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest {
|
"fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest {
|
||||||
|
@ -616,7 +619,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
|
||||||
auth: a,
|
auth: a,
|
||||||
crt: renewDisabledCrt,
|
crt: renewDisabledCrt,
|
||||||
err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
|
err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
|
||||||
http.StatusUnauthorized, context{"serialNumber": "119772236532068856521070735128919532568"}},
|
http.StatusUnauthorized, apiCtx{"serialNumber": "119772236532068856521070735128919532568"}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) *authorizeTest {
|
"ok": func(t *testing.T) *authorizeTest {
|
||||||
|
|
|
@ -28,11 +28,19 @@ var (
|
||||||
Renegotiation: false,
|
Renegotiation: false,
|
||||||
}
|
}
|
||||||
defaultDisableRenewal = false
|
defaultDisableRenewal = false
|
||||||
|
defaultEnableSSHCA = false
|
||||||
globalProvisionerClaims = provisioner.Claims{
|
globalProvisionerClaims = provisioner.Claims{
|
||||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
DisableRenewal: &defaultDisableRenewal,
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||||
|
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
|
DefaultUserSSHDur: &provisioner.Duration{Duration: 4 * time.Hour},
|
||||||
|
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||||
|
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||||
|
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||||
|
EnableSSHCA: &defaultEnableSSHCA,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,6 +52,7 @@ type Config struct {
|
||||||
IntermediateKey string `json:"key"`
|
IntermediateKey string `json:"key"`
|
||||||
Address string `json:"address"`
|
Address string `json:"address"`
|
||||||
DNSNames []string `json:"dnsNames"`
|
DNSNames []string `json:"dnsNames"`
|
||||||
|
SSH *SSHConfig `json:"ssh,omitempty"`
|
||||||
Logger json.RawMessage `json:"logger,omitempty"`
|
Logger json.RawMessage `json:"logger,omitempty"`
|
||||||
DB *db.Config `json:"db,omitempty"`
|
DB *db.Config `json:"db,omitempty"`
|
||||||
Monitoring json.RawMessage `json:"monitoring,omitempty"`
|
Monitoring json.RawMessage `json:"monitoring,omitempty"`
|
||||||
|
@ -92,6 +101,14 @@ func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHConfig contains the user and host keys.
|
||||||
|
type SSHConfig struct {
|
||||||
|
HostKey string `json:"hostKey"`
|
||||||
|
UserKey string `json:"userKey"`
|
||||||
|
AddUserPrincipal string `json:"addUserPrincipal"`
|
||||||
|
AddUserCommand string `json:"addUserCommand"`
|
||||||
|
}
|
||||||
|
|
||||||
// LoadConfiguration parses the given filename in JSON format and returns the
|
// LoadConfiguration parses the given filename in JSON format and returns the
|
||||||
// configuration struct.
|
// configuration struct.
|
||||||
func LoadConfiguration(filename string) (*Config, error) {
|
func LoadConfiguration(filename string) (*Config, error) {
|
||||||
|
|
|
@ -4,13 +4,13 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type context map[string]interface{}
|
type apiCtx map[string]interface{}
|
||||||
|
|
||||||
// Error implements the api.Error interface and adds context to error messages.
|
// Error implements the api.Error interface and adds context to error messages.
|
||||||
type apiError struct {
|
type apiError struct {
|
||||||
err error
|
err error
|
||||||
code int
|
code int
|
||||||
context context
|
context apiCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cause implements the errors.Causer interface and returns the original error.
|
// Cause implements the errors.Causer interface and returns the original error.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -266,13 +267,21 @@ func (p *AWS) Init(config Config) (err error) {
|
||||||
|
|
||||||
// AuthorizeSign validates the given token and returns the sign options that
|
// AuthorizeSign validates the given token and returns the sign options that
|
||||||
// will be used on certificate creation.
|
// will be used on certificate creation.
|
||||||
func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) {
|
func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
payload, err := p.authorizeToken(token)
|
payload, err := p.authorizeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
doc := payload.document
|
|
||||||
|
|
||||||
|
// Check for the sign ssh method, default to sign X.509
|
||||||
|
if m := MethodFromContext(ctx); m == SignSSHMethod {
|
||||||
|
if p.claimer.IsSSHCAEnabled() == false {
|
||||||
|
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||||
|
}
|
||||||
|
return p.authorizeSSHSign(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
doc := payload.document
|
||||||
// Enforce known CN and default DNS and IP if configured.
|
// Enforce known CN and default DNS and IP if configured.
|
||||||
// By default we'll accept the CN and SANs in the CSR.
|
// By default we'll accept the CN and SANs in the CSR.
|
||||||
// There's no way to trust them other than TOFU.
|
// There's no way to trust them other than TOFU.
|
||||||
|
@ -433,3 +442,35 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
||||||
payload.document = doc
|
payload.document = doc
|
||||||
return &payload, nil
|
return &payload, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
|
func (p *AWS) authorizeSSHSign(claims *awsPayload) ([]SignOption, error) {
|
||||||
|
doc := claims.document
|
||||||
|
|
||||||
|
signOptions := []SignOption{
|
||||||
|
// set the key id to the token subject
|
||||||
|
sshCertificateKeyIDModifier(claims.Subject),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to host + known IPs/hostnames
|
||||||
|
defaults := SSHOptions{
|
||||||
|
CertType: SSHHostCert,
|
||||||
|
Principals: []string{
|
||||||
|
doc.PrivateIP,
|
||||||
|
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Validate user options
|
||||||
|
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||||
|
// Set defaults if not given as user options
|
||||||
|
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||||
|
|
||||||
|
return append(signOptions,
|
||||||
|
// set the default extensions
|
||||||
|
&sshDefaultExtensionModifier{},
|
||||||
|
// checks the validity bounds, and set the validity if has not been set
|
||||||
|
&sshCertificateValidityModifier{p.claimer},
|
||||||
|
// require all the fields in the SSH certificate
|
||||||
|
&sshCertificateDefaultValidator{},
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
@ -347,7 +349,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.aws.AuthorizeSign(tt.args.token)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
|
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -357,6 +360,84 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAWS_AuthorizeSign_SSH(t *testing.T) {
|
||||||
|
tm, fn := mockNow()
|
||||||
|
defer fn()
|
||||||
|
|
||||||
|
p1, srv, err := generateAWSWithServer()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
signer, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||||
|
expectedHostOptions := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
expectedHostOptionsIP := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"127.0.0.1"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
expectedHostOptionsHostname := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
sshOpts SSHOptions
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
aws *AWS
|
||||||
|
args args
|
||||||
|
expected *SSHOptions
|
||||||
|
wantErr bool
|
||||||
|
wantSignErr bool
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}}, expectedHostOptionsIP, false, false},
|
||||||
|
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptionsHostname, false, false},
|
||||||
|
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false},
|
||||||
|
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||||
|
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||||
|
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}}, nil, false, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||||
|
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
assert.Nil(t, got)
|
||||||
|
} else if assert.NotNil(t, got) {
|
||||||
|
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||||
|
if (err != nil) != tt.wantSignErr {
|
||||||
|
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||||
|
} else {
|
||||||
|
if tt.wantSignErr {
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
func TestAWS_AuthorizeRenewal(t *testing.T) {
|
func TestAWS_AuthorizeRenewal(t *testing.T) {
|
||||||
p1, err := generateAWS()
|
p1, err := generateAWS()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
@ -209,7 +210,7 @@ func (p *Azure) Init(config Config) (err error) {
|
||||||
|
|
||||||
// AuthorizeSign validates the given token and returns the sign options that
|
// AuthorizeSign validates the given token and returns the sign options that
|
||||||
// will be used on certificate creation.
|
// will be used on certificate creation.
|
||||||
func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
jwt, err := jose.ParseSigned(token)
|
jwt, err := jose.ParseSigned(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "error parsing token")
|
return nil, errors.Wrapf(err, "error parsing token")
|
||||||
|
@ -264,6 +265,14 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for the sign ssh method, default to sign X.509
|
||||||
|
if m := MethodFromContext(ctx); m == SignSSHMethod {
|
||||||
|
if p.claimer.IsSSHCAEnabled() == false {
|
||||||
|
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||||
|
}
|
||||||
|
return p.authorizeSSHSign(claims, name)
|
||||||
|
}
|
||||||
|
|
||||||
// Enforce known common name and default DNS if configured.
|
// Enforce known common name and default DNS if configured.
|
||||||
// By default we'll accept the CN and SANs in the CSR.
|
// By default we'll accept the CN and SANs in the CSR.
|
||||||
// There's no way to trust them other than TOFU.
|
// There's no way to trust them other than TOFU.
|
||||||
|
@ -296,6 +305,33 @@ func (p *Azure) AuthorizeRevoke(token string) error {
|
||||||
return errors.New("revoke is not supported on a Azure provisioner")
|
return errors.New("revoke is not supported on a Azure provisioner")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
|
func (p *Azure) authorizeSSHSign(claims azurePayload, name string) ([]SignOption, error) {
|
||||||
|
signOptions := []SignOption{
|
||||||
|
// set the key id to the token subject
|
||||||
|
sshCertificateKeyIDModifier(name),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to host + known hostnames
|
||||||
|
defaults := SSHOptions{
|
||||||
|
CertType: SSHHostCert,
|
||||||
|
Principals: []string{name},
|
||||||
|
}
|
||||||
|
// Validate user options
|
||||||
|
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||||
|
// Set defaults if not given as user options
|
||||||
|
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||||
|
|
||||||
|
return append(signOptions,
|
||||||
|
// set the default extensions
|
||||||
|
&sshDefaultExtensionModifier{},
|
||||||
|
// checks the validity bounds, and set the validity if has not been set
|
||||||
|
&sshCertificateValidityModifier{p.claimer},
|
||||||
|
// require all the fields in the SSH certificate
|
||||||
|
&sshCertificateDefaultValidator{},
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
||||||
// assertConfig initializes the config if it has not been initialized
|
// assertConfig initializes the config if it has not been initialized
|
||||||
func (p *Azure) assertConfig() {
|
func (p *Azure) assertConfig() {
|
||||||
if p.config == nil {
|
if p.config == nil {
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
@ -295,7 +297,8 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.azure.AuthorizeSign(tt.args.token)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
|
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -305,6 +308,75 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAzure_AuthorizeSign_SSH(t *testing.T) {
|
||||||
|
tm, fn := mockNow()
|
||||||
|
defer fn()
|
||||||
|
|
||||||
|
p1, srv, err := generateAzureWithServer()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
signer, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||||
|
expectedHostOptions := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"virtualMachine"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
sshOpts SSHOptions
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
azure *Azure
|
||||||
|
args args
|
||||||
|
expected *SSHOptions
|
||||||
|
wantErr bool
|
||||||
|
wantSignErr bool
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false},
|
||||||
|
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||||
|
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||||
|
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}}, nil, false, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||||
|
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
assert.Nil(t, got)
|
||||||
|
} else if assert.NotNil(t, got) {
|
||||||
|
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||||
|
if (err != nil) != tt.wantSignErr {
|
||||||
|
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||||
|
} else {
|
||||||
|
if tt.wantSignErr {
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAzure_AuthorizeRenewal(t *testing.T) {
|
func TestAzure_AuthorizeRenewal(t *testing.T) {
|
||||||
p1, err := generateAzure()
|
p1, err := generateAzure()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
|
@ -8,10 +8,19 @@ import (
|
||||||
|
|
||||||
// Claims so that individual provisioners can override global claims.
|
// Claims so that individual provisioners can override global claims.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
|
// TLS CA properties
|
||||||
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||||
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||||
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||||
|
// SSH CA properties
|
||||||
|
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
|
||||||
|
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
|
||||||
|
DefaultUserSSHDur *Duration `json:"defaultUserSSHCertDuration,omitempty"`
|
||||||
|
MinHostSSHDur *Duration `json:"minHostSSHCertDuration,omitempty"`
|
||||||
|
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
|
||||||
|
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
|
||||||
|
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claimer is the type that controls claims. It provides an interface around the
|
// Claimer is the type that controls claims. It provides an interface around the
|
||||||
|
@ -30,11 +39,19 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
|
||||||
// Claims returns the merge of the inner and global claims.
|
// Claims returns the merge of the inner and global claims.
|
||||||
func (c *Claimer) Claims() Claims {
|
func (c *Claimer) Claims() Claims {
|
||||||
disableRenewal := c.IsDisableRenewal()
|
disableRenewal := c.IsDisableRenewal()
|
||||||
|
enableSSHCA := c.IsSSHCAEnabled()
|
||||||
return Claims{
|
return Claims{
|
||||||
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
||||||
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
||||||
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
||||||
DisableRenewal: &disableRenewal,
|
DisableRenewal: &disableRenewal,
|
||||||
|
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
|
||||||
|
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
|
||||||
|
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
|
||||||
|
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
|
||||||
|
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
|
||||||
|
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
|
||||||
|
EnableSSHCA: &enableSSHCA,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +95,76 @@ func (c *Claimer) IsDisableRenewal() bool {
|
||||||
return *c.claims.DisableRenewal
|
return *c.claims.DisableRenewal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultUserSSHCertDuration returns the default SSH user cert duration for the
|
||||||
|
// provisioner. If the default is not set within the provisioner, then the
|
||||||
|
// global default from the authority configuration will be used.
|
||||||
|
func (c *Claimer) DefaultUserSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.DefaultUserSSHDur == nil {
|
||||||
|
return c.global.DefaultUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.DefaultUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinUserSSHCertDuration returns the minimum SSH user cert duration for the
|
||||||
|
// provisioner. If the minimum is not set within the provisioner, then the
|
||||||
|
// global minimum from the authority configuration will be used.
|
||||||
|
func (c *Claimer) MinUserSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.MinUserSSHDur == nil {
|
||||||
|
return c.global.MinUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.MinUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxUserSSHCertDuration returns the maximum SSH user cert duration for the
|
||||||
|
// provisioner. If the maximum is not set within the provisioner, then the
|
||||||
|
// global maximum from the authority configuration will be used.
|
||||||
|
func (c *Claimer) MaxUserSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.MaxUserSSHDur == nil {
|
||||||
|
return c.global.MaxUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.MaxUserSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultHostSSHCertDuration returns the default SSH host cert duration for the
|
||||||
|
// provisioner. If the default is not set within the provisioner, then the
|
||||||
|
// global default from the authority configuration will be used.
|
||||||
|
func (c *Claimer) DefaultHostSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.DefaultHostSSHDur == nil {
|
||||||
|
return c.global.DefaultHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.DefaultHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinHostSSHCertDuration returns the minimum SSH host cert duration for the
|
||||||
|
// provisioner. If the minimum is not set within the provisioner, then the
|
||||||
|
// global minimum from the authority configuration will be used.
|
||||||
|
func (c *Claimer) MinHostSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.MinHostSSHDur == nil {
|
||||||
|
return c.global.MinHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.MinHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxHostSSHCertDuration returns the maximum SSH Host cert duration for the
|
||||||
|
// provisioner. If the maximum is not set within the provisioner, then the
|
||||||
|
// global maximum from the authority configuration will be used.
|
||||||
|
func (c *Claimer) MaxHostSSHCertDuration() time.Duration {
|
||||||
|
if c.claims == nil || c.claims.MaxHostSSHDur == nil {
|
||||||
|
return c.global.MaxHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
return c.claims.MaxHostSSHDur.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSSHCAEnabled returns if the SSH CA is enabled for the provisioner. If the
|
||||||
|
// property is not set within the provisioner, then the global value from the
|
||||||
|
// authority configuration will be used.
|
||||||
|
func (c *Claimer) IsSSHCAEnabled() bool {
|
||||||
|
if c.claims == nil || c.claims.EnableSSHCA == nil {
|
||||||
|
return *c.global.EnableSSHCA
|
||||||
|
}
|
||||||
|
return *c.claims.EnableSSHCA
|
||||||
|
}
|
||||||
|
|
||||||
// Validate validates and modifies the Claims with default values.
|
// Validate validates and modifies the Claims with default values.
|
||||||
func (c *Claimer) Validate() error {
|
func (c *Claimer) Validate() error {
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -2,6 +2,7 @@ package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
@ -205,13 +206,21 @@ func (p *GCP) Init(config Config) error {
|
||||||
|
|
||||||
// AuthorizeSign validates the given token and returns the sign options that
|
// AuthorizeSign validates the given token and returns the sign options that
|
||||||
// will be used on certificate creation.
|
// will be used on certificate creation.
|
||||||
func (p *GCP) AuthorizeSign(token string) ([]SignOption, error) {
|
func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
claims, err := p.authorizeToken(token)
|
claims, err := p.authorizeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ce := claims.Google.ComputeEngine
|
|
||||||
|
|
||||||
|
// Check for the sign ssh method, default to sign X.509
|
||||||
|
if m := MethodFromContext(ctx); m == SignSSHMethod {
|
||||||
|
if p.claimer.IsSSHCAEnabled() == false {
|
||||||
|
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||||
|
}
|
||||||
|
return p.authorizeSSHSign(claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
ce := claims.Google.ComputeEngine
|
||||||
// Enforce known common name and default DNS if configured.
|
// Enforce known common name and default DNS if configured.
|
||||||
// By default we we'll accept the CN and SANs in the CSR.
|
// By default we we'll accept the CN and SANs in the CSR.
|
||||||
// There's no way to trust them other than TOFU.
|
// There's no way to trust them other than TOFU.
|
||||||
|
@ -345,3 +354,35 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
||||||
|
|
||||||
return &claims, nil
|
return &claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
|
func (p *GCP) authorizeSSHSign(claims *gcpPayload) ([]SignOption, error) {
|
||||||
|
ce := claims.Google.ComputeEngine
|
||||||
|
|
||||||
|
signOptions := []SignOption{
|
||||||
|
// set the key id to the token subject
|
||||||
|
sshCertificateKeyIDModifier(ce.InstanceName),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to host + known hostnames
|
||||||
|
defaults := SSHOptions{
|
||||||
|
CertType: SSHHostCert,
|
||||||
|
Principals: []string{
|
||||||
|
fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID),
|
||||||
|
fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Validate user options
|
||||||
|
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||||
|
// Set defaults if not given as user options
|
||||||
|
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||||
|
|
||||||
|
return append(signOptions,
|
||||||
|
// set the default extensions
|
||||||
|
&sshDefaultExtensionModifier{},
|
||||||
|
// checks the validity bounds, and set the validity if has not been set
|
||||||
|
&sshCertificateValidityModifier{p.claimer},
|
||||||
|
// require all the fields in the SSH certificate
|
||||||
|
&sshCertificateDefaultValidator{},
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
@ -330,7 +332,8 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.gcp.AuthorizeSign(tt.args.token)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
|
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -340,6 +343,87 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGCP_AuthorizeSign_SSH(t *testing.T) {
|
||||||
|
tm, fn := mockNow()
|
||||||
|
defer fn()
|
||||||
|
|
||||||
|
p1, err := generateGCP()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
t1, err := generateGCPToken(p1.ServiceAccounts[0],
|
||||||
|
"https://accounts.google.com", p1.GetID(),
|
||||||
|
"instance-id", "instance-name", "project-id", "zone",
|
||||||
|
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
signer, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||||
|
expectedHostOptions := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
expectedHostOptionsPrincipal1 := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"instance-name.c.project-id.internal"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
expectedHostOptionsPrincipal2 := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"instance-name.zone.c.project-id.internal"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
sshOpts SSHOptions
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
gcp *GCP
|
||||||
|
args args
|
||||||
|
expected *SSHOptions
|
||||||
|
wantErr bool
|
||||||
|
wantSignErr bool
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false},
|
||||||
|
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}}, expectedHostOptionsPrincipal1, false, false},
|
||||||
|
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}}, expectedHostOptionsPrincipal2, false, false},
|
||||||
|
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false},
|
||||||
|
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||||
|
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||||
|
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}}, nil, false, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||||
|
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
assert.Nil(t, got)
|
||||||
|
} else if assert.NotNil(t, got) {
|
||||||
|
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||||
|
if (err != nil) != tt.wantSignErr {
|
||||||
|
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||||
|
} else {
|
||||||
|
if tt.wantSignErr {
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGCP_AuthorizeRenewal(t *testing.T) {
|
func TestGCP_AuthorizeRenewal(t *testing.T) {
|
||||||
p1, err := generateGCP()
|
p1, err := generateGCP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -12,7 +13,12 @@ import (
|
||||||
// jwtPayload extends jwt.Claims with step attributes.
|
// jwtPayload extends jwt.Claims with step attributes.
|
||||||
type jwtPayload struct {
|
type jwtPayload struct {
|
||||||
jose.Claims
|
jose.Claims
|
||||||
SANs []string `json:"sans,omitempty"`
|
SANs []string `json:"sans,omitempty"`
|
||||||
|
Step *stepPayload `json:"step,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type stepPayload struct {
|
||||||
|
SSH *SSHOptions `json:"ssh,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWK is the default provisioner, an entity that can sign tokens necessary for
|
// JWK is the default provisioner, an entity that can sign tokens necessary for
|
||||||
|
@ -129,11 +135,20 @@ func (p *JWK) AuthorizeRevoke(token string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign validates the given token.
|
// AuthorizeSign validates the given token.
|
||||||
func (p *JWK) AuthorizeSign(token string) ([]SignOption, error) {
|
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for SSH token
|
||||||
|
if claims.Step != nil && claims.Step.SSH != nil {
|
||||||
|
if p.claimer.IsSSHCAEnabled() == false {
|
||||||
|
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||||
|
}
|
||||||
|
return p.authorizeSSHSign(claims)
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: This is for backwards compatibility with older versions of cli
|
// NOTE: This is for backwards compatibility with older versions of cli
|
||||||
// and certificates. Older versions added the token subject as the only SAN
|
// and certificates. Older versions added the token subject as the only SAN
|
||||||
// in a CSR by default.
|
// in a CSR by default.
|
||||||
|
@ -161,3 +176,41 @@ func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
|
func (p *JWK) authorizeSSHSign(claims *jwtPayload) ([]SignOption, error) {
|
||||||
|
t := now()
|
||||||
|
opts := claims.Step.SSH
|
||||||
|
signOptions := []SignOption{
|
||||||
|
// validates user's SSHOptions with the ones in the token
|
||||||
|
sshCertificateOptionsValidator(*opts),
|
||||||
|
// set the key id to the token subject
|
||||||
|
sshCertificateKeyIDModifier(claims.Subject),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add modifiers from custom claims
|
||||||
|
if opts.CertType != "" {
|
||||||
|
signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType))
|
||||||
|
}
|
||||||
|
if len(opts.Principals) > 0 {
|
||||||
|
signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals))
|
||||||
|
}
|
||||||
|
if !opts.ValidAfter.IsZero() {
|
||||||
|
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
|
||||||
|
}
|
||||||
|
if !opts.ValidBefore.IsZero() {
|
||||||
|
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to a user certificate with no principals if not set
|
||||||
|
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})
|
||||||
|
|
||||||
|
return append(signOptions,
|
||||||
|
// set the default extensions
|
||||||
|
&sshDefaultExtensionModifier{},
|
||||||
|
// checks the validity bounds, and set the validity if has not been set
|
||||||
|
&sshCertificateValidityModifier{p.claimer},
|
||||||
|
// require all the fields in the SSH certificate
|
||||||
|
&sshCertificateDefaultValidator{},
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -13,11 +15,19 @@ import (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultDisableRenewal = false
|
defaultDisableRenewal = false
|
||||||
|
defaultEnableSSHCA = true
|
||||||
globalProvisionerClaims = Claims{
|
globalProvisionerClaims = Claims{
|
||||||
MinTLSDur: &Duration{5 * time.Minute},
|
MinTLSDur: &Duration{5 * time.Minute},
|
||||||
MaxTLSDur: &Duration{24 * time.Hour},
|
MaxTLSDur: &Duration{24 * time.Hour},
|
||||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||||
DisableRenewal: &defaultDisableRenewal,
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||||
|
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
|
||||||
|
DefaultUserSSHDur: &Duration{Duration: 4 * time.Hour},
|
||||||
|
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||||
|
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||||
|
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||||
|
EnableSSHCA: &defaultEnableSSHCA,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -259,7 +269,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got, err := tt.prov.AuthorizeSign(tt.args.token); err != nil {
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
|
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
|
||||||
if assert.NotNil(t, tt.err) {
|
if assert.NotNil(t, tt.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -318,3 +329,201 @@ func TestJWK_AuthorizeRenewal(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJWK_AuthorizeSign_SSH(t *testing.T) {
|
||||||
|
tm, fn := mockNow()
|
||||||
|
defer fn()
|
||||||
|
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
iss, aud := p1.Name, testAudiences.Sign[0]
|
||||||
|
|
||||||
|
t1, err := generateSimpleSSHUserToken(iss, aud, jwk)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
t2, err := generateSimpleSSHHostToken(iss, aud, jwk)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// invalid signature
|
||||||
|
failSig := t1[0 : len(t1)-2]
|
||||||
|
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
signer, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||||
|
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||||
|
expectedUserOptions := &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||||
|
}
|
||||||
|
expectedHostOptions := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"smallstep.com"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
sshOpts SSHOptions
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *JWK
|
||||||
|
args args
|
||||||
|
expected *SSHOptions
|
||||||
|
wantErr bool
|
||||||
|
wantSignErr bool
|
||||||
|
}{
|
||||||
|
{"user", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false},
|
||||||
|
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false},
|
||||||
|
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||||
|
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||||
|
{"host", p1, args{t2, SSHOptions{}}, expectedHostOptions, false, false},
|
||||||
|
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||||
|
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||||
|
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||||
|
{"fail-signature", p1, args{failSig, SSHOptions{}}, nil, true, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||||
|
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
assert.Nil(t, got)
|
||||||
|
} else if assert.NotNil(t, got) {
|
||||||
|
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||||
|
if (err != nil) != tt.wantSignErr {
|
||||||
|
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||||
|
} else {
|
||||||
|
if tt.wantSignErr {
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
|
||||||
|
tm, fn := mockNow()
|
||||||
|
defer fn()
|
||||||
|
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
sub, iss, aud, iat := "subject@smallstep.com", p1.Name, testAudiences.Sign[0], time.Now()
|
||||||
|
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
signer, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||||
|
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||||
|
expectedUserOptions := &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||||
|
}
|
||||||
|
expectedHostOptions := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"smallstep.com"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
sub, iss, aud string
|
||||||
|
iat time.Time
|
||||||
|
tokSSHOpts *SSHOptions
|
||||||
|
userSSHOpts *SSHOptions
|
||||||
|
jwk *jose.JSONWebKey
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *JWK
|
||||||
|
args args
|
||||||
|
expected *SSHOptions
|
||||||
|
wantErr bool
|
||||||
|
wantSignErr bool
|
||||||
|
}{
|
||||||
|
{"ok-user", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, expectedUserOptions, false, false},
|
||||||
|
{"ok-host", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, &SSHOptions{}, jwk}, expectedHostOptions, false, false},
|
||||||
|
{"ok-user-opts", p1, args{sub, iss, aud, iat, &SSHOptions{}, &SSHOptions{CertType: "user", Principals: []string{"name"}}, jwk}, expectedUserOptions, false, false},
|
||||||
|
{"ok-host-opts", p1, args{sub, iss, aud, iat, &SSHOptions{}, &SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, jwk}, expectedHostOptions, false, false},
|
||||||
|
{"ok-user-mixed", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user"}, &SSHOptions{Principals: []string{"name"}}, jwk}, expectedUserOptions, false, false},
|
||||||
|
{"ok-host-mixed", p1, args{sub, iss, aud, iat, &SSHOptions{Principals: []string{"smallstep.com"}}, &SSHOptions{CertType: "host"}, jwk}, expectedHostOptions, false, false},
|
||||||
|
{"ok-user-validAfter", p1, args{sub, iss, aud, iat, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"},
|
||||||
|
}, &SSHOptions{
|
||||||
|
ValidAfter: NewTimeDuration(tm.Add(-time.Hour)),
|
||||||
|
}, jwk}, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(-time.Hour)), ValidBefore: NewTimeDuration(tm.Add(userDuration - time.Hour)),
|
||||||
|
}, false, false},
|
||||||
|
{"ok-user-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"},
|
||||||
|
}, &SSHOptions{
|
||||||
|
ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||||
|
}, jwk}, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||||
|
}, false, false},
|
||||||
|
{"ok-user-validAfter-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"},
|
||||||
|
}, &SSHOptions{
|
||||||
|
ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||||
|
}, jwk}, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||||
|
}, false, false},
|
||||||
|
{"ok-user-match", p1, args{sub, iss, aud, iat, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)),
|
||||||
|
}, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)),
|
||||||
|
}, jwk}, &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||||
|
}, false, false},
|
||||||
|
{"fail-certType", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{CertType: "host"}, jwk}, nil, false, true},
|
||||||
|
{"fail-principals", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{Principals: []string{"root"}}, jwk}, nil, false, true},
|
||||||
|
{"fail-validAfter", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm)}, &SSHOptions{ValidAfter: NewTimeDuration(tm.Add(time.Hour))}, jwk}, nil, false, true},
|
||||||
|
{"fail-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}, ValidBefore: NewTimeDuration(tm.Add(time.Hour))}, &SSHOptions{ValidBefore: NewTimeDuration(tm.Add(10 * time.Hour))}, jwk}, nil, false, true},
|
||||||
|
{"fail-subject", p1, args{"", iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||||
|
{"fail-issuer", p1, args{sub, "invalid", aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||||
|
{"fail-audience", p1, args{sub, iss, "invalid", iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||||
|
{"fail-expired", p1, args{sub, iss, aud, iat.Add(-6 * time.Minute), &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||||
|
{"fail-notBefore", p1, args{sub, iss, aud, iat.Add(5 * time.Minute), &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||||
|
token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
if got, err := tt.prov.AuthorizeSign(ctx, token); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("JWK.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
} else if !tt.wantErr && assert.NotNil(t, got) {
|
||||||
|
var opts SSHOptions
|
||||||
|
if tt.args.userSSHOpts != nil {
|
||||||
|
opts = *tt.args.userSSHOpts
|
||||||
|
}
|
||||||
|
cert, err := signSSHCertificate(key.Public().Key, opts, got, signer.Key.(crypto.Signer))
|
||||||
|
if (err != nil) != tt.wantSignErr {
|
||||||
|
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||||
|
} else {
|
||||||
|
if tt.wantSignErr {
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
34
authority/provisioner/method.go
Normal file
34
authority/provisioner/method.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Method indicates the action to action that we will perform, it's used as part
|
||||||
|
// of the context in the call to authorize. It defaults to Sing.
|
||||||
|
type Method int
|
||||||
|
|
||||||
|
// The key to save the Method in the context.
|
||||||
|
type methodKey struct{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SignMethod is the method used to sign X.509 certificates.
|
||||||
|
SignMethod Method = iota
|
||||||
|
// SignSSHMethod is the method used to sign SSH certificate.
|
||||||
|
SignSSHMethod
|
||||||
|
// RevokeMethod is the method used to revoke X.509 certificates.
|
||||||
|
RevokeMethod
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewContextWithMethod creates a new context from ctx and attaches method to
|
||||||
|
// it.
|
||||||
|
func NewContextWithMethod(ctx context.Context, method Method) context.Context {
|
||||||
|
return context.WithValue(ctx, methodKey{}, method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MethodFromContext returns the Method saved in ctx. Returns Sign if the given
|
||||||
|
// context has no Method associated with it.
|
||||||
|
func MethodFromContext(ctx context.Context) Method {
|
||||||
|
m, _ := ctx.Value(methodKey{}).(Method)
|
||||||
|
return m
|
||||||
|
}
|
|
@ -1,6 +1,9 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import "crypto/x509"
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
)
|
||||||
|
|
||||||
// noop provisioners is a provisioner that accepts anything.
|
// noop provisioners is a provisioner that accepts anything.
|
||||||
type noop struct{}
|
type noop struct{}
|
||||||
|
@ -28,7 +31,7 @@ func (p *noop) Init(config Config) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *noop) AuthorizeSign(token string) ([]SignOption, error) {
|
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
return []SignOption{}, nil
|
return []SignOption{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -21,7 +22,8 @@ func Test_noop(t *testing.T) {
|
||||||
assert.Equals(t, "", key)
|
assert.Equals(t, "", key)
|
||||||
assert.Equals(t, false, ok)
|
assert.Equals(t, false, ok)
|
||||||
|
|
||||||
sigOptions, err := p.AuthorizeSign("foo")
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
|
sigOptions, err := p.AuthorizeSign(ctx, "foo")
|
||||||
assert.Equals(t, []SignOption{}, sigOptions)
|
assert.Equals(t, []SignOption{}, sigOptions)
|
||||||
assert.Equals(t, nil, err)
|
assert.Equals(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -259,12 +260,29 @@ func (o *OIDC) AuthorizeRevoke(token string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign validates the given token.
|
// AuthorizeSign validates the given token.
|
||||||
func (o *OIDC) AuthorizeSign(token string) ([]SignOption, error) {
|
func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
claims, err := o.authorizeToken(token)
|
claims, err := o.authorizeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for the sign ssh method, default to sign X.509
|
||||||
|
if m := MethodFromContext(ctx); m == SignSSHMethod {
|
||||||
|
if o.claimer.IsSSHCAEnabled() == false {
|
||||||
|
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", o.GetID())
|
||||||
|
}
|
||||||
|
return o.authorizeSSHSign(claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admins should be able to authorize any SAN
|
||||||
|
if o.IsAdmin(claims.Email) {
|
||||||
|
return []SignOption{
|
||||||
|
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
|
||||||
|
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||||
|
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
so := []SignOption{
|
so := []SignOption{
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
|
||||||
|
@ -287,6 +305,42 @@ func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
|
func (o *OIDC) authorizeSSHSign(claims *openIDPayload) ([]SignOption, error) {
|
||||||
|
signOptions := []SignOption{
|
||||||
|
// set the key id to the token subject
|
||||||
|
sshCertificateKeyIDModifier(claims.Email),
|
||||||
|
}
|
||||||
|
|
||||||
|
name := SanitizeSSHUserPrincipal(claims.Email)
|
||||||
|
if !sshUserRegex.MatchString(name) {
|
||||||
|
return nil, errors.Errorf("invalid principal '%s' from email address '%s'", name, claims.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin users will default to user + name but they can be changed by the
|
||||||
|
// user options. Non-admins are only able to sign user certificates.
|
||||||
|
defaults := SSHOptions{
|
||||||
|
CertType: SSHUserCert,
|
||||||
|
Principals: []string{name},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !o.IsAdmin(claims.Email) {
|
||||||
|
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to a user with name as principal if not set
|
||||||
|
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||||
|
|
||||||
|
return append(signOptions,
|
||||||
|
// set the default extensions
|
||||||
|
&sshDefaultExtensionModifier{},
|
||||||
|
// checks the validity bounds, and set the validity if has not been set
|
||||||
|
&sshCertificateValidityModifier{o.claimer},
|
||||||
|
// require all the fields in the SSH certificate
|
||||||
|
&sshCertificateDefaultValidator{},
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
||||||
func getAndDecode(uri string, v interface{}) error {
|
func getAndDecode(uri string, v interface{}) error {
|
||||||
resp, err := http.Get(uri)
|
resp, err := http.Get(uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -276,7 +278,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.prov.AuthorizeSign(tt.args.token)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
|
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -286,7 +289,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
||||||
} else {
|
} else {
|
||||||
assert.NotNil(t, got)
|
assert.NotNil(t, got)
|
||||||
if tt.name == "admin" {
|
if tt.name == "admin" {
|
||||||
assert.Len(t, 4, got)
|
assert.Len(t, 3, got)
|
||||||
} else {
|
} else {
|
||||||
assert.Len(t, 5, got)
|
assert.Len(t, 5, got)
|
||||||
}
|
}
|
||||||
|
@ -295,6 +298,117 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
|
||||||
|
tm, fn := mockNow()
|
||||||
|
defer fn()
|
||||||
|
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
var keys jose.JSONWebKeySet
|
||||||
|
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||||
|
|
||||||
|
// Create test provisioners
|
||||||
|
p1, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p3, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// Admin + Domains
|
||||||
|
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
|
||||||
|
p3.Domains = []string{"smallstep.com"}
|
||||||
|
|
||||||
|
// Update configuration endpoints and initialize
|
||||||
|
config := Config{Claims: globalProvisionerClaims}
|
||||||
|
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||||
|
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||||
|
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||||
|
assert.FatalError(t, p1.Init(config))
|
||||||
|
assert.FatalError(t, p2.Init(config))
|
||||||
|
assert.FatalError(t, p3.Init(config))
|
||||||
|
|
||||||
|
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// Admin email not in domains
|
||||||
|
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{}, time.Now(), &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// Invalid email
|
||||||
|
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
signer, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||||
|
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||||
|
expectedUserOptions := &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"name"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||||
|
}
|
||||||
|
expectedAdminOptions := &SSHOptions{
|
||||||
|
CertType: "user", Principals: []string{"root"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||||
|
}
|
||||||
|
expectedHostOptions := &SSHOptions{
|
||||||
|
CertType: "host", Principals: []string{"smallstep.com"},
|
||||||
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
sshOpts SSHOptions
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *OIDC
|
||||||
|
args args
|
||||||
|
expected *SSHOptions
|
||||||
|
wantErr bool
|
||||||
|
wantSignErr bool
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false},
|
||||||
|
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false},
|
||||||
|
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||||
|
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||||
|
{"admin", p3, args{okAdmin, SSHOptions{}}, expectedAdminOptions, false, false},
|
||||||
|
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}}, expectedAdminOptions, false, false},
|
||||||
|
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}}, expectedAdminOptions, false, false},
|
||||||
|
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||||
|
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||||
|
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}}, nil, false, true},
|
||||||
|
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}}, nil, false, true},
|
||||||
|
{"fail-email", p3, args{failEmail, SSHOptions{}}, nil, true, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||||
|
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
assert.Nil(t, got)
|
||||||
|
} else if assert.NotNil(t, got) {
|
||||||
|
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||||
|
if (err != nil) != tt.wantSignErr {
|
||||||
|
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||||
|
} else {
|
||||||
|
if tt.wantSignErr {
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||||
srv := generateJWKServer(2)
|
srv := generateJWKServer(2)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -17,7 +19,7 @@ type Interface interface {
|
||||||
GetType() Type
|
GetType() Type
|
||||||
GetEncryptedKey() (kid string, key string, ok bool)
|
GetEncryptedKey() (kid string, key string, ok bool)
|
||||||
Init(config Config) error
|
Init(config Config) error
|
||||||
AuthorizeSign(token string) ([]SignOption, error)
|
AuthorizeSign(ctx context.Context, token string) ([]SignOption, error)
|
||||||
AuthorizeRenewal(cert *x509.Certificate) error
|
AuthorizeRenewal(cert *x509.Certificate) error
|
||||||
AuthorizeRevoke(token string) error
|
AuthorizeRevoke(token string) error
|
||||||
}
|
}
|
||||||
|
@ -169,3 +171,29 @@ func (l *List) UnmarshalJSON(data []byte) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
|
||||||
|
|
||||||
|
// SanitizeSSHUserPrincipal grabs an email or a string with the format
|
||||||
|
// local@domain and returns a sanitized version of the local, valid to be used
|
||||||
|
// as a user name. If the email starts with a letter between a and z, the
|
||||||
|
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
|
||||||
|
func SanitizeSSHUserPrincipal(email string) string {
|
||||||
|
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||||
|
email = email[:i]
|
||||||
|
}
|
||||||
|
return strings.Map(func(r rune) rune {
|
||||||
|
switch {
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
return r
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
return r
|
||||||
|
case r == '-':
|
||||||
|
return '-'
|
||||||
|
case r == '.': // drop dots
|
||||||
|
return -1
|
||||||
|
default:
|
||||||
|
return '_'
|
||||||
|
}
|
||||||
|
}, strings.ToLower(email))
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
func TestType_String(t *testing.T) {
|
func TestType_String(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -24,3 +26,29 @@ func TestType_String(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeSSHUserPrincipal(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
email string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"simple", args{"foobar"}, "foobar"},
|
||||||
|
{"camelcase", args{"FooBar"}, "foobar"},
|
||||||
|
{"email", args{"foo@example.com"}, "foo"},
|
||||||
|
{"email with dots", args{"foo.bar.zar@example.com"}, "foobarzar"},
|
||||||
|
{"email with dashes", args{"foo-bar-zar@example.com"}, "foo-bar-zar"},
|
||||||
|
{"email with underscores", args{"foo_bar_zar@example.com"}, "foo_bar_zar"},
|
||||||
|
{"email with symbols", args{"Foo.Bar0123456789!#$%&'*+-/=?^_`{|}~;@example.com"}, "foobar0123456789________-___________"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := SanitizeSSHUserPrincipal(tt.args.email); got != tt.want {
|
||||||
|
t.Errorf("SanitizeSSHUserPrincipal() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
306
authority/provisioner/sign_ssh_options.go
Normal file
306
authority/provisioner/sign_ssh_options.go
Normal file
|
@ -0,0 +1,306 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SSHUserCert is the string used to represent ssh.UserCert.
|
||||||
|
SSHUserCert = "user"
|
||||||
|
|
||||||
|
// SSHHostCert is the string used to represent ssh.HostCert.
|
||||||
|
SSHHostCert = "host"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHCertificateModifier is the interface used to change properties in an SSH
|
||||||
|
// certificate.
|
||||||
|
type SSHCertificateModifier interface {
|
||||||
|
SignOption
|
||||||
|
Modify(cert *ssh.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateOptionModifier is the interface used to add custom options used
|
||||||
|
// to modify the SSH certificate.
|
||||||
|
type SSHCertificateOptionModifier interface {
|
||||||
|
SignOption
|
||||||
|
Option(o SSHOptions) SSHCertificateModifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateValidator is the interface used to validate an SSH certificate.
|
||||||
|
type SSHCertificateValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(cert *ssh.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHCertificateOptionsValidator is the interface used to validate the custom
|
||||||
|
// options used to modify the SSH certificate.
|
||||||
|
type SSHCertificateOptionsValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(got SSHOptions) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHOptions contains the options that can be passed to the SignSSH method.
|
||||||
|
type SSHOptions struct {
|
||||||
|
CertType string `json:"certType"`
|
||||||
|
Principals []string `json:"principals"`
|
||||||
|
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||||
|
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the uint32 representation of the CertType.
|
||||||
|
func (o SSHOptions) Type() uint32 {
|
||||||
|
return sshCertTypeUInt32(o.CertType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
|
||||||
|
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
|
||||||
|
switch o.CertType {
|
||||||
|
case "": // ignore
|
||||||
|
case SSHUserCert:
|
||||||
|
cert.CertType = ssh.UserCert
|
||||||
|
case SSHHostCert:
|
||||||
|
cert.CertType = ssh.HostCert
|
||||||
|
default:
|
||||||
|
return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType)
|
||||||
|
}
|
||||||
|
cert.ValidPrincipals = o.Principals
|
||||||
|
if !o.ValidAfter.IsZero() {
|
||||||
|
cert.ValidAfter = uint64(o.ValidAfter.Time().Unix())
|
||||||
|
}
|
||||||
|
if !o.ValidBefore.IsZero() {
|
||||||
|
cert.ValidBefore = uint64(o.ValidBefore.Time().Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
|
||||||
|
return errors.New("ssh certificate valid after cannot be greater than valid before")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// match compares two SSHOptions and return an error if they don't match. It
|
||||||
|
// ignores zero values.
|
||||||
|
func (o SSHOptions) match(got SSHOptions) error {
|
||||||
|
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
|
||||||
|
return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
|
||||||
|
}
|
||||||
|
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
|
||||||
|
return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
|
||||||
|
}
|
||||||
|
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
|
||||||
|
return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
|
||||||
|
}
|
||||||
|
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
|
||||||
|
return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given
|
||||||
|
// Key ID in the SSH certificate.
|
||||||
|
type sshCertificateKeyIDModifier string
|
||||||
|
|
||||||
|
func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.KeyId = string(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the
|
||||||
|
// certificate type to the SSH certificate.
|
||||||
|
type sshCertificateCertTypeModifier string
|
||||||
|
|
||||||
|
func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.CertType = sshCertTypeUInt32(string(m))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the
|
||||||
|
// principals to the SSH certificate.
|
||||||
|
type sshCertificatePrincipalsModifier []string
|
||||||
|
|
||||||
|
func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidPrincipals = []string(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the
|
||||||
|
// ValidAfter in the SSH certificate.
|
||||||
|
type sshCertificateValidAfterModifier uint64
|
||||||
|
|
||||||
|
func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidAfter = uint64(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the
|
||||||
|
// ValidBefore in the SSH certificate.
|
||||||
|
type sshCertificateValidBeforeModifier uint64
|
||||||
|
|
||||||
|
func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
cert.ValidBefore = uint64(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateDefaultModifier implements a SSHCertificateModifier that
|
||||||
|
// modifies the certificate with the given options if they are not set.
|
||||||
|
type sshCertificateDefaultsModifier SSHOptions
|
||||||
|
|
||||||
|
// Modify implements the SSHCertificateModifier interface.
|
||||||
|
func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
if cert.CertType == 0 {
|
||||||
|
cert.CertType = sshCertTypeUInt32(m.CertType)
|
||||||
|
}
|
||||||
|
if len(cert.ValidPrincipals) == 0 {
|
||||||
|
cert.ValidPrincipals = m.Principals
|
||||||
|
}
|
||||||
|
if cert.ValidAfter == 0 && !m.ValidAfter.IsZero() {
|
||||||
|
cert.ValidAfter = uint64(m.ValidAfter.Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidBefore == 0 && !m.ValidBefore.IsZero() {
|
||||||
|
cert.ValidBefore = uint64(m.ValidBefore.Unix())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
|
||||||
|
// the default extensions in an SSH certificate.
|
||||||
|
type sshDefaultExtensionModifier struct{}
|
||||||
|
|
||||||
|
func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
switch cert.CertType {
|
||||||
|
// Default to no extensions for HostCert.
|
||||||
|
case ssh.HostCert:
|
||||||
|
return nil
|
||||||
|
case ssh.UserCert:
|
||||||
|
if cert.Extensions == nil {
|
||||||
|
cert.Extensions = make(map[string]string)
|
||||||
|
}
|
||||||
|
cert.Extensions["permit-X11-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-agent-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-port-forwarding"] = ""
|
||||||
|
cert.Extensions["permit-pty"] = ""
|
||||||
|
cert.Extensions["permit-user-rc"] = ""
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.New("ssh certificate type has not been set or is invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateValidityModifier is a SSHCertificateModifier checks the
|
||||||
|
// validity bounds, setting them if they are not provided. It will fail if a
|
||||||
|
// CertType has not been set or is not valid.
|
||||||
|
type sshCertificateValidityModifier struct {
|
||||||
|
*Claimer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
var d, min, max time.Duration
|
||||||
|
switch cert.CertType {
|
||||||
|
case ssh.UserCert:
|
||||||
|
d = m.DefaultUserSSHCertDuration()
|
||||||
|
min = m.MinUserSSHCertDuration()
|
||||||
|
max = m.MaxUserSSHCertDuration()
|
||||||
|
case ssh.HostCert:
|
||||||
|
d = m.DefaultHostSSHCertDuration()
|
||||||
|
min = m.MinHostSSHCertDuration()
|
||||||
|
max = m.MaxHostSSHCertDuration()
|
||||||
|
case 0:
|
||||||
|
return errors.New("ssh certificate type has not been set")
|
||||||
|
default:
|
||||||
|
return errors.Errorf("unknown ssh certificate type %d", cert.CertType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cert.ValidAfter == 0 {
|
||||||
|
cert.ValidAfter = uint64(now().Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidBefore == 0 {
|
||||||
|
t := time.Unix(int64(cert.ValidAfter), 0)
|
||||||
|
cert.ValidBefore = uint64(t.Add(d).Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
diff := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
|
||||||
|
switch {
|
||||||
|
case diff < min:
|
||||||
|
return errors.Errorf("ssh certificate duration cannot be lower than %s", min)
|
||||||
|
case diff > max:
|
||||||
|
return errors.Errorf("ssh certificate duration cannot be greater than %s", max)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
|
||||||
|
// usually present in the token.
|
||||||
|
type sshCertificateOptionsValidator SSHOptions
|
||||||
|
|
||||||
|
// Valid implements SSHCertificateOptionsValidator and returns nil if both
|
||||||
|
// SSHOptions match.
|
||||||
|
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error {
|
||||||
|
want := SSHOptions(v)
|
||||||
|
return want.match(got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertificateDefaultValidator implements a simple validator for all the
|
||||||
|
// fields in the SSH certificate.
|
||||||
|
type sshCertificateDefaultValidator struct{}
|
||||||
|
|
||||||
|
// Valid returns an error if the given certificate does not contain the necessary fields.
|
||||||
|
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
|
||||||
|
switch {
|
||||||
|
case len(cert.Nonce) == 0:
|
||||||
|
return errors.New("ssh certificate nonce cannot be empty")
|
||||||
|
case cert.Key == nil:
|
||||||
|
return errors.New("ssh certificate key cannot be nil")
|
||||||
|
case cert.Serial == 0:
|
||||||
|
return errors.New("ssh certificate serial cannot be 0")
|
||||||
|
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
|
||||||
|
return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType)
|
||||||
|
case cert.KeyId == "":
|
||||||
|
return errors.New("ssh certificate key id cannot be empty")
|
||||||
|
case len(cert.ValidPrincipals) == 0:
|
||||||
|
return errors.New("ssh certificate valid principals cannot be empty")
|
||||||
|
case cert.ValidAfter == 0:
|
||||||
|
return errors.New("ssh certificate valid after cannot be 0")
|
||||||
|
case cert.ValidBefore == 0:
|
||||||
|
return errors.New("ssh certificate valid before cannot be 0")
|
||||||
|
case cert.CertType == ssh.UserCert && len(cert.Extensions) == 0:
|
||||||
|
return errors.New("ssh certificate extensions cannot be empty")
|
||||||
|
case cert.SignatureKey == nil:
|
||||||
|
return errors.New("ssh certificate signature key cannot be nil")
|
||||||
|
case cert.Signature == nil:
|
||||||
|
return errors.New("ssh certificate signature cannot be nil")
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshCertTypeUInt32
|
||||||
|
func sshCertTypeUInt32(ct string) uint32 {
|
||||||
|
switch ct {
|
||||||
|
case SSHUserCert:
|
||||||
|
return ssh.UserCert
|
||||||
|
case SSHHostCert:
|
||||||
|
return ssh.HostCert
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsAllMembers reports whether all members of subgroup are within group.
|
||||||
|
func containsAllMembers(group, subgroup []string) bool {
|
||||||
|
lg, lsg := len(group), len(subgroup)
|
||||||
|
if lsg > lg || (lg > 0 && lsg == 0) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
visit := make(map[string]struct{}, lg)
|
||||||
|
for i := 0; i < lg; i++ {
|
||||||
|
visit[group[i]] = struct{}{}
|
||||||
|
}
|
||||||
|
for i := 0; i < lsg; i++ {
|
||||||
|
if _, ok := visit[subgroup[i]]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
125
authority/provisioner/ssh_test.go
Normal file
125
authority/provisioner/ssh_test.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validateSSHCertificate(cert *ssh.Certificate, opts *SSHOptions) error {
|
||||||
|
switch {
|
||||||
|
case cert == nil:
|
||||||
|
return fmt.Errorf("certificate is nil")
|
||||||
|
case cert.Signature == nil:
|
||||||
|
return fmt.Errorf("certificate signature is nil")
|
||||||
|
case cert.SignatureKey == nil:
|
||||||
|
return fmt.Errorf("certificate signature is nil")
|
||||||
|
case !reflect.DeepEqual(cert.ValidPrincipals, opts.Principals):
|
||||||
|
return fmt.Errorf("certificate principals are not equal, want %v, got %v", opts.Principals, cert.ValidPrincipals)
|
||||||
|
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
|
||||||
|
return fmt.Errorf("certificate type %v is not valid", cert.CertType)
|
||||||
|
case opts.CertType == "user" && cert.CertType != ssh.UserCert:
|
||||||
|
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.UserCert, cert.CertType)
|
||||||
|
case opts.CertType == "host" && cert.CertType != ssh.HostCert:
|
||||||
|
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.HostCert, cert.CertType)
|
||||||
|
case cert.ValidAfter != uint64(opts.ValidAfter.Unix()):
|
||||||
|
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
|
||||||
|
case cert.ValidBefore != uint64(opts.ValidBefore.Unix()):
|
||||||
|
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
|
||||||
|
case opts.CertType == "user" && len(cert.Extensions) != 5:
|
||||||
|
return fmt.Errorf("certificate extensions number is invalid, want 5, got %d", len(cert.Extensions))
|
||||||
|
case opts.CertType == "host" && len(cert.Extensions) != 0:
|
||||||
|
return fmt.Errorf("certificate extensions number is invalid, want 0, got %d", len(cert.Extensions))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOption, signKey crypto.Signer) (*ssh.Certificate, error) {
|
||||||
|
pub, err := ssh.NewPublicKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var mods []SSHCertificateModifier
|
||||||
|
var validators []SSHCertificateValidator
|
||||||
|
|
||||||
|
for _, op := range signOpts {
|
||||||
|
switch o := op.(type) {
|
||||||
|
// modify the ssh.Certificate
|
||||||
|
case SSHCertificateModifier:
|
||||||
|
mods = append(mods, o)
|
||||||
|
// modify the ssh.Certificate given the SSHOptions
|
||||||
|
case SSHCertificateOptionModifier:
|
||||||
|
mods = append(mods, o.Option(opts))
|
||||||
|
// validate the ssh.Certificate
|
||||||
|
case SSHCertificateValidator:
|
||||||
|
validators = append(validators, o)
|
||||||
|
// validate the given SSHOptions
|
||||||
|
case SSHCertificateOptionsValidator:
|
||||||
|
if err := o.Valid(opts); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("signSSH: invalid extra option type %T", o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build base certificate with the key and some random values
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0},
|
||||||
|
Key: pub,
|
||||||
|
Serial: 1234567890,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use opts to modify the certificate
|
||||||
|
if err := opts.Modify(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use provisioner modifiers
|
||||||
|
for _, m := range mods {
|
||||||
|
if err := m.Modify(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get signer from authority keys
|
||||||
|
var signer ssh.Signer
|
||||||
|
switch cert.CertType {
|
||||||
|
case ssh.UserCert:
|
||||||
|
signer, err = ssh.NewSignerFromSigner(signKey)
|
||||||
|
case ssh.HostCert:
|
||||||
|
signer, err = ssh.NewSignerFromSigner(signKey)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected ssh certificate type: %d", cert.CertType)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.SignatureKey = signer.PublicKey()
|
||||||
|
|
||||||
|
// Get bytes for signing trailing the signature length.
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
// Sign the certificate
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
|
||||||
|
// User provisioners validators
|
||||||
|
for _, v := range validators {
|
||||||
|
if err := v.Valid(cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
|
@ -57,6 +57,17 @@ func (t *TimeDuration) SetTime(tt time.Time) {
|
||||||
t.t, t.d = tt, 0
|
t.t, t.d = tt, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsZero returns true the TimeDuration represents the zero value, false
|
||||||
|
// otherwise.
|
||||||
|
func (t *TimeDuration) IsZero() bool {
|
||||||
|
return t.t.IsZero() && t.d == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equal returns if t and other are equal.
|
||||||
|
func (t *TimeDuration) Equal(other *TimeDuration) bool {
|
||||||
|
return t.t.Equal(other.t) && t.d == other.d
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalJSON implements the json.Marshaler interface. If the time is set it
|
// MarshalJSON implements the json.Marshaler interface. If the time is set it
|
||||||
// will return the time in RFC 3339 format if not it will return the duration
|
// will return the time in RFC 3339 format if not it will return the duration
|
||||||
// string.
|
// string.
|
||||||
|
@ -64,7 +75,7 @@ func (t TimeDuration) MarshalJSON() ([]byte, error) {
|
||||||
switch {
|
switch {
|
||||||
case t.t.IsZero():
|
case t.t.IsZero():
|
||||||
if t.d == 0 {
|
if t.d == 0 {
|
||||||
return []byte("null"), nil
|
return []byte(`""`), nil
|
||||||
}
|
}
|
||||||
return json.Marshal(t.d.String())
|
return json.Marshal(t.d.String())
|
||||||
default:
|
default:
|
||||||
|
@ -102,11 +113,16 @@ func (t *TimeDuration) UnmarshalJSON(data []byte) error {
|
||||||
return errors.Errorf("failed to parse %s", data)
|
return errors.Errorf("failed to parse %s", data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Time calculates the embedded time.Time, sets it if necessary, and returns it.
|
// Time calculates the time if needed and returns it.
|
||||||
func (t *TimeDuration) Time() time.Time {
|
func (t *TimeDuration) Time() time.Time {
|
||||||
return t.RelativeTime(now())
|
return t.RelativeTime(now())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unix calculates the time if needed it and returns the Unix time in seconds.
|
||||||
|
func (t *TimeDuration) Unix() int64 {
|
||||||
|
return t.RelativeTime(now()).Unix()
|
||||||
|
}
|
||||||
|
|
||||||
// RelativeTime returns the embedded time.Time or the base time plus the
|
// RelativeTime returns the embedded time.Time or the base time plus the
|
||||||
// duration if this is not zero.
|
// duration if this is not zero.
|
||||||
func (t *TimeDuration) RelativeTime(base time.Time) time.Time {
|
func (t *TimeDuration) RelativeTime(base time.Time) time.Time {
|
||||||
|
|
|
@ -6,6 +6,17 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func mockNow() (time.Time, func()) {
|
||||||
|
tm := time.Unix(1584198566, 535897000).UTC()
|
||||||
|
nowFn := now
|
||||||
|
now = func() time.Time {
|
||||||
|
return tm
|
||||||
|
}
|
||||||
|
return tm, func() {
|
||||||
|
now = nowFn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewTimeDuration(t *testing.T) {
|
func TestNewTimeDuration(t *testing.T) {
|
||||||
tm := time.Unix(1584198566, 535897000).UTC()
|
tm := time.Unix(1584198566, 535897000).UTC()
|
||||||
type args struct {
|
type args struct {
|
||||||
|
@ -137,7 +148,7 @@ func TestTimeDuration_MarshalJSON(t *testing.T) {
|
||||||
want []byte
|
want []byte
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"null", TimeDuration{}, []byte("null"), false},
|
{"empty", TimeDuration{}, []byte(`""`), false},
|
||||||
{"timestamp", TimeDuration{t: tm}, []byte(`"2020-03-14T15:09:26.535897Z"`), false},
|
{"timestamp", TimeDuration{t: tm}, []byte(`"2020-03-14T15:09:26.535897Z"`), false},
|
||||||
{"duration", TimeDuration{d: 1 * time.Hour}, []byte(`"1h0m0s"`), false},
|
{"duration", TimeDuration{d: 1 * time.Hour}, []byte(`"1h0m0s"`), false},
|
||||||
{"fail", TimeDuration{t: time.Date(-1, 0, 0, 0, 0, 0, 0, time.UTC)}, nil, true},
|
{"fail", TimeDuration{t: time.Date(-1, 0, 0, 0, 0, 0, 0, time.UTC)}, nil, true},
|
||||||
|
@ -166,7 +177,7 @@ func TestTimeDuration_UnmarshalJSON(t *testing.T) {
|
||||||
want *TimeDuration
|
want *TimeDuration
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"null", args{[]byte("null")}, &TimeDuration{}, false},
|
{"empty", args{[]byte(`""`)}, &TimeDuration{}, false},
|
||||||
{"timestamp", args{[]byte(`"2020-03-14T15:09:26.535897Z"`)}, &TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false},
|
{"timestamp", args{[]byte(`"2020-03-14T15:09:26.535897Z"`)}, &TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false},
|
||||||
{"duration", args{[]byte(`"1h"`)}, &TimeDuration{d: time.Hour}, false},
|
{"duration", args{[]byte(`"1h"`)}, &TimeDuration{d: time.Hour}, false},
|
||||||
{"fail", args{[]byte("123")}, &TimeDuration{}, true},
|
{"fail", args{[]byte("123")}, &TimeDuration{}, true},
|
||||||
|
@ -186,15 +197,8 @@ func TestTimeDuration_UnmarshalJSON(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimeDuration_Time(t *testing.T) {
|
func TestTimeDuration_Time(t *testing.T) {
|
||||||
nowFn := now
|
tm, fn := mockNow()
|
||||||
defer func() {
|
defer fn()
|
||||||
now = nowFn
|
|
||||||
now()
|
|
||||||
}()
|
|
||||||
tm := time.Unix(1584198566, 535897000).UTC()
|
|
||||||
now = func() time.Time {
|
|
||||||
return tm
|
|
||||||
}
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
timeDuration *TimeDuration
|
timeDuration *TimeDuration
|
||||||
|
@ -211,6 +215,30 @@ func TestTimeDuration_Time(t *testing.T) {
|
||||||
got := tt.timeDuration.Time()
|
got := tt.timeDuration.Time()
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("TimeDuration.Time() = %v, want %v", got, tt.want)
|
t.Errorf("TimeDuration.Time() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimeDuration_Unix(t *testing.T) {
|
||||||
|
tm, fn := mockNow()
|
||||||
|
defer fn()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeDuration *TimeDuration
|
||||||
|
want int64
|
||||||
|
}{
|
||||||
|
{"zero", nil, -62135596800},
|
||||||
|
{"zero", &TimeDuration{}, -62135596800},
|
||||||
|
{"timestamp", &TimeDuration{t: tm}, 1584198566},
|
||||||
|
{"local", &TimeDuration{t: tm.Local()}, 1584198566},
|
||||||
|
{"duration", &TimeDuration{d: 1 * time.Hour}, 1584202166},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.timeDuration.Unix()
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("TimeDuration.Unix() = %v, want %v", got, tt.want)
|
||||||
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -218,15 +246,8 @@ func TestTimeDuration_Time(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimeDuration_String(t *testing.T) {
|
func TestTimeDuration_String(t *testing.T) {
|
||||||
nowFn := now
|
tm, fn := mockNow()
|
||||||
defer func() {
|
defer fn()
|
||||||
now = nowFn
|
|
||||||
now()
|
|
||||||
}()
|
|
||||||
tm := time.Unix(1584198566, 535897000).UTC()
|
|
||||||
now = func() time.Time {
|
|
||||||
return tm
|
|
||||||
}
|
|
||||||
type fields struct {
|
type fields struct {
|
||||||
t time.Time
|
t time.Time
|
||||||
d time.Duration
|
d time.Duration
|
||||||
|
|
|
@ -480,6 +480,54 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
|
||||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateSimpleSSHUserToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
|
||||||
|
return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SSHOptions{
|
||||||
|
CertType: "user",
|
||||||
|
Principals: []string{"name"},
|
||||||
|
}, jwk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSimpleSSHHostToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
|
||||||
|
return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SSHOptions{
|
||||||
|
CertType: "host",
|
||||||
|
Principals: []string{"smallstep.com"},
|
||||||
|
}, jwk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *SSHOptions, jwk *jose.JSONWebKey) (string, error) {
|
||||||
|
sig, err := jose.NewSigner(
|
||||||
|
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||||
|
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := randutil.ASCII(64)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := struct {
|
||||||
|
jose.Claims
|
||||||
|
Step *stepPayload `json:"step,omitempty"`
|
||||||
|
}{
|
||||||
|
Claims: jose.Claims{
|
||||||
|
ID: id,
|
||||||
|
Subject: sub,
|
||||||
|
Issuer: iss,
|
||||||
|
IssuedAt: jose.NewNumericDate(iat),
|
||||||
|
NotBefore: jose.NewNumericDate(iat),
|
||||||
|
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||||
|
Audience: []string{aud},
|
||||||
|
},
|
||||||
|
Step: &stepPayload{
|
||||||
|
SSH: sshOpts,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
|
}
|
||||||
|
|
||||||
func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||||
sig, err := jose.NewSigner(
|
sig, err := jose.NewSigner(
|
||||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||||
|
|
|
@ -13,7 +13,7 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||||
key, ok := a.provisioners.LoadEncryptedKey(kid)
|
key, ok := a.provisioners.LoadEncryptedKey(kid)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
||||||
http.StatusNotFound, context{}}
|
http.StatusNotFound, apiCtx{}}
|
||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
|
||||||
p, ok := a.provisioners.LoadByCertificate(crt)
|
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &apiError{errors.Errorf("provisioner not found"),
|
return nil, &apiError{errors.Errorf("provisioner not found"),
|
||||||
http.StatusNotFound, context{}}
|
http.StatusNotFound, apiCtx{}}
|
||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ func TestGetEncryptedKey(t *testing.T) {
|
||||||
a: a,
|
a: a,
|
||||||
kid: "foo",
|
kid: "foo",
|
||||||
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
|
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
|
||||||
http.StatusNotFound, context{}},
|
http.StatusNotFound, apiCtx{}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,13 +12,13 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
||||||
val, ok := a.certificates.Load(sum)
|
val, ok := a.certificates.Load(sum)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
|
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
|
||||||
http.StatusNotFound, context{}}
|
http.StatusNotFound, apiCtx{}}
|
||||||
}
|
}
|
||||||
|
|
||||||
crt, ok := val.(*x509.Certificate)
|
crt, ok := val.(*x509.Certificate)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||||
http.StatusInternalServerError, context{}}
|
http.StatusInternalServerError, apiCtx{}}
|
||||||
}
|
}
|
||||||
return crt, nil
|
return crt, nil
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
|
||||||
if !ok {
|
if !ok {
|
||||||
federation = nil
|
federation = nil
|
||||||
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||||
http.StatusInternalServerError, context{}}
|
http.StatusInternalServerError, apiCtx{}}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
federation = append(federation, crt)
|
federation = append(federation, crt)
|
||||||
|
|
|
@ -19,8 +19,8 @@ func TestRoot(t *testing.T) {
|
||||||
sum string
|
sum string
|
||||||
err *apiError
|
err *apiError
|
||||||
}{
|
}{
|
||||||
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
|
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
|
||||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
|
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
|
||||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
239
authority/ssh.go
Normal file
239
authority/ssh.go
Normal file
|
@ -0,0 +1,239 @@
|
||||||
|
package authority
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SSHAddUserPrincipal is the principal that will run the add user command.
|
||||||
|
// Defaults to "provisioner" but it can be changed in the configuration.
|
||||||
|
SSHAddUserPrincipal = "provisioner"
|
||||||
|
|
||||||
|
// SSHAddUserCommand is the default command to run to add a new user.
|
||||||
|
// Defaults to "sudo useradd -m <principal>; nc -q0 localhost 22" but it can be changed in the
|
||||||
|
// configuration. The string "<principal>" will be replace by the new
|
||||||
|
// principal to add.
|
||||||
|
SSHAddUserCommand = "sudo useradd -m <principal>; nc -q0 localhost 22"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SignSSH creates a signed SSH certificate with the given public key and options.
|
||||||
|
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
|
var mods []provisioner.SSHCertificateModifier
|
||||||
|
var validators []provisioner.SSHCertificateValidator
|
||||||
|
|
||||||
|
for _, op := range signOpts {
|
||||||
|
switch o := op.(type) {
|
||||||
|
// modify the ssh.Certificate
|
||||||
|
case provisioner.SSHCertificateModifier:
|
||||||
|
mods = append(mods, o)
|
||||||
|
// modify the ssh.Certificate given the SSHOptions
|
||||||
|
case provisioner.SSHCertificateOptionModifier:
|
||||||
|
mods = append(mods, o.Option(opts))
|
||||||
|
// validate the ssh.Certificate
|
||||||
|
case provisioner.SSHCertificateValidator:
|
||||||
|
validators = append(validators, o)
|
||||||
|
// validate the given SSHOptions
|
||||||
|
case provisioner.SSHCertificateOptionsValidator:
|
||||||
|
if err := o.Valid(opts); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Errorf("signSSH: invalid extra option type %T", o),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, err := randutil.ASCII(32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||||
|
}
|
||||||
|
|
||||||
|
var serial uint64
|
||||||
|
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSH: error reading random number"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build base certificate with the key and some random values
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte(nonce),
|
||||||
|
Key: key,
|
||||||
|
Serial: serial,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use opts to modify the certificate
|
||||||
|
if err := opts.Modify(cert); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use provisioner modifiers
|
||||||
|
for _, m := range mods {
|
||||||
|
if err := m.Modify(cert); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get signer from authority keys
|
||||||
|
var signer ssh.Signer
|
||||||
|
switch cert.CertType {
|
||||||
|
case ssh.UserCert:
|
||||||
|
if a.sshCAUserCertSignKey == nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSH: user certificate signing is not enabled"),
|
||||||
|
code: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if signer, err = ssh.NewSignerFromSigner(a.sshCAUserCertSignKey); err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSH: error creating signer"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case ssh.HostCert:
|
||||||
|
if a.sshCAHostCertSignKey == nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSH: host certificate signing is not enabled"),
|
||||||
|
code: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if signer, err = ssh.NewSignerFromSigner(a.sshCAHostCertSignKey); err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSH: error creating signer"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cert.SignatureKey = signer.PublicKey()
|
||||||
|
|
||||||
|
// Get bytes for signing trailing the signature length.
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
// Sign the certificate
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSH: error signing certificate"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
|
||||||
|
// User provisioners validators
|
||||||
|
for _, v := range validators {
|
||||||
|
if err := v.Valid(cert); err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignSSHAddUser signs a certificate that provisions a new user in a server.
|
||||||
|
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
|
if a.sshCAUserCertSignKey == nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSHAddUser: user certificate signing is not enabled"),
|
||||||
|
code: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if subject.CertType != ssh.UserCert {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSHProxy: certificate is not a user certificate"),
|
||||||
|
code: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(subject.ValidPrincipals) != 1 {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.New("signSSHProxy: certificate does not have only one principal"),
|
||||||
|
code: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, err := randutil.ASCII(32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||||
|
}
|
||||||
|
|
||||||
|
var serial uint64
|
||||||
|
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSHProxy: error reading random number"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := ssh.NewSignerFromSigner(a.sshCAUserCertSignKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "signSSHProxy: error creating signer"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
principal := subject.ValidPrincipals[0]
|
||||||
|
addUserPrincipal := a.getAddUserPrincipal()
|
||||||
|
|
||||||
|
cert := &ssh.Certificate{
|
||||||
|
Nonce: []byte(nonce),
|
||||||
|
Key: key,
|
||||||
|
Serial: serial,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: principal + "-" + addUserPrincipal,
|
||||||
|
ValidPrincipals: []string{addUserPrincipal},
|
||||||
|
ValidAfter: subject.ValidAfter,
|
||||||
|
ValidBefore: subject.ValidBefore,
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: map[string]string{
|
||||||
|
"force-command": a.getAddUserCommand(principal),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SignatureKey: signer.PublicKey(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get bytes for signing trailing the signature length.
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
// Sign the certificate
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
||||||
|
if a.config.SSH.AddUserPrincipal == "" {
|
||||||
|
return SSHAddUserPrincipal
|
||||||
|
}
|
||||||
|
return a.config.SSH.AddUserPrincipal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) getAddUserCommand(principal string) string {
|
||||||
|
var cmd string
|
||||||
|
if a.config.SSH.AddUserCommand == "" {
|
||||||
|
cmd = SSHAddUserCommand
|
||||||
|
} else {
|
||||||
|
cmd = a.config.SSH.AddUserCommand
|
||||||
|
}
|
||||||
|
return strings.Replace(cmd, "<principal>", principal, -1)
|
||||||
|
}
|
252
authority/ssh_test.go
Normal file
252
authority/ssh_test.go
Normal file
|
@ -0,0 +1,252 @@
|
||||||
|
package authority
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sshTestModifier ssh.Certificate
|
||||||
|
|
||||||
|
func (m sshTestModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
if m.CertType != 0 {
|
||||||
|
cert.CertType = m.CertType
|
||||||
|
}
|
||||||
|
if m.KeyId != "" {
|
||||||
|
cert.KeyId = m.KeyId
|
||||||
|
}
|
||||||
|
if m.ValidAfter != 0 {
|
||||||
|
cert.ValidAfter = m.ValidAfter
|
||||||
|
}
|
||||||
|
if m.ValidBefore != 0 {
|
||||||
|
cert.ValidBefore = m.ValidBefore
|
||||||
|
}
|
||||||
|
if len(m.ValidPrincipals) != 0 {
|
||||||
|
cert.ValidPrincipals = m.ValidPrincipals
|
||||||
|
}
|
||||||
|
if m.Permissions.CriticalOptions != nil {
|
||||||
|
cert.Permissions.CriticalOptions = m.Permissions.CriticalOptions
|
||||||
|
}
|
||||||
|
if m.Permissions.Extensions != nil {
|
||||||
|
cert.Permissions.Extensions = m.Permissions.Extensions
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshTestCertModifier string
|
||||||
|
|
||||||
|
func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error {
|
||||||
|
if m == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf(string(m))
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshTestCertValidator string
|
||||||
|
|
||||||
|
func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error {
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf(string(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshTestOptionsValidator string
|
||||||
|
|
||||||
|
func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf(string(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshTestOptionsModifier string
|
||||||
|
|
||||||
|
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier {
|
||||||
|
return sshTestCertModifier(string(m))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthority_SignSSH(t *testing.T) {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
pub, err := ssh.NewPublicKey(key.Public())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
userOptions := sshTestModifier{
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
}
|
||||||
|
hostOptions := sshTestModifier{
|
||||||
|
CertType: ssh.HostCert,
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
sshCAUserCertSignKey crypto.Signer
|
||||||
|
sshCAHostCertSignKey crypto.Signer
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
key ssh.PublicKey
|
||||||
|
opts provisioner.SSHOptions
|
||||||
|
signOpts []provisioner.SignOption
|
||||||
|
}
|
||||||
|
type want struct {
|
||||||
|
CertType uint32
|
||||||
|
Principals []string
|
||||||
|
ValidAfter uint64
|
||||||
|
ValidBefore uint64
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want want
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{hostOptions}}, want{CertType: ssh.HostCert}, false},
|
||||||
|
{"ok-opts-type-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-opts-type-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert}, false},
|
||||||
|
{"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
|
||||||
|
{"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false},
|
||||||
|
{"ok-opts-valid-after", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false},
|
||||||
|
{"ok-opts-valid-before", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false},
|
||||||
|
{"ok-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"fail-opts-type", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "foo"}, []provisioner.SignOption{}}, want{}, true},
|
||||||
|
{"fail-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("an error")}}, want{}, true},
|
||||||
|
{"fail-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("an error")}}, want{}, true},
|
||||||
|
{"fail-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("an error")}}, want{}, true},
|
||||||
|
{"fail-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("an error")}}, want{}, true},
|
||||||
|
{"fail-bad-sign-options", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, "wrong type"}}, want{}, true},
|
||||||
|
{"fail-no-user-key", fields{nil, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{}, true},
|
||||||
|
{"fail-no-host-key", fields{signKey, nil}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{}, true},
|
||||||
|
{"fail-bad-type", fields{signKey, nil}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{sshTestModifier{CertType: 0}}}, want{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := testAuthority(t)
|
||||||
|
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
||||||
|
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
||||||
|
|
||||||
|
got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil && assert.NotNil(t, got) {
|
||||||
|
assert.Equals(t, tt.want.CertType, got.CertType)
|
||||||
|
assert.Equals(t, tt.want.Principals, got.ValidPrincipals)
|
||||||
|
assert.Equals(t, tt.want.ValidAfter, got.ValidAfter)
|
||||||
|
assert.Equals(t, tt.want.ValidBefore, got.ValidBefore)
|
||||||
|
assert.NotNil(t, got.Key)
|
||||||
|
assert.NotNil(t, got.Nonce)
|
||||||
|
assert.NotEquals(t, 0, got.Serial)
|
||||||
|
assert.NotNil(t, got.Signature)
|
||||||
|
assert.NotNil(t, got.SignatureKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthority_SignSSHAddUser(t *testing.T) {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
pub, err := ssh.NewPublicKey(key.Public())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
sshCAUserCertSignKey crypto.Signer
|
||||||
|
sshCAHostCertSignKey crypto.Signer
|
||||||
|
addUserPrincipal string
|
||||||
|
addUserCommand string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
key ssh.PublicKey
|
||||||
|
subject *ssh.Certificate
|
||||||
|
}
|
||||||
|
type want struct {
|
||||||
|
CertType uint32
|
||||||
|
Principals []string
|
||||||
|
ValidAfter uint64
|
||||||
|
ValidBefore uint64
|
||||||
|
ForceCommand string
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
validCert := &ssh.Certificate{
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
ValidPrincipals: []string{"user"},
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
validWant := want{
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
Principals: []string{"provisioner"},
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
ForceCommand: "sudo useradd -m user; nc -q0 localhost 22",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want want
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{signKey, signKey, "", ""}, args{pub, validCert}, validWant, false},
|
||||||
|
{"ok-no-host-key", fields{signKey, nil, "", ""}, args{pub, validCert}, validWant, false},
|
||||||
|
{"ok-custom-principal", fields{signKey, signKey, "my-principal", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "sudo useradd -m user; nc -q0 localhost 22"}, false},
|
||||||
|
{"ok-custom-command", fields{signKey, signKey, "", "foo <principal> <principal>"}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"provisioner"}, ForceCommand: "foo user user"}, false},
|
||||||
|
{"ok-custom-principal-and-command", fields{signKey, signKey, "my-principal", "foo <principal> <principal>"}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "foo user user"}, false},
|
||||||
|
{"fail-no-user-key", fields{nil, signKey, "", ""}, args{pub, validCert}, want{}, true},
|
||||||
|
{"fail-no-user-cert", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"foo"}}}, want{}, true},
|
||||||
|
{"fail-no-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, want{}, true},
|
||||||
|
{"fail-many-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"foo", "bar"}}}, want{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := testAuthority(t)
|
||||||
|
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
||||||
|
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
||||||
|
a.config.SSH = &SSHConfig{
|
||||||
|
AddUserPrincipal: tt.fields.addUserPrincipal,
|
||||||
|
AddUserCommand: tt.fields.addUserCommand,
|
||||||
|
}
|
||||||
|
got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil && assert.NotNil(t, got) {
|
||||||
|
assert.Equals(t, tt.want.CertType, got.CertType)
|
||||||
|
assert.Equals(t, tt.want.Principals, got.ValidPrincipals)
|
||||||
|
assert.Equals(t, tt.args.subject.ValidPrincipals[0]+"-"+tt.want.Principals[0], got.KeyId)
|
||||||
|
assert.Equals(t, tt.want.ValidAfter, got.ValidAfter)
|
||||||
|
assert.Equals(t, tt.want.ValidBefore, got.ValidBefore)
|
||||||
|
assert.Equals(t, map[string]string{"force-command": tt.want.ForceCommand}, got.CriticalOptions)
|
||||||
|
assert.Equals(t, nil, got.Extensions)
|
||||||
|
assert.NotNil(t, got.Key)
|
||||||
|
assert.NotNil(t, got.Nonce)
|
||||||
|
assert.NotEquals(t, 0, got.Serial)
|
||||||
|
assert.NotNil(t, got.Signature)
|
||||||
|
assert.NotNil(t, got.SignatureKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -58,7 +58,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
||||||
// Sign creates a signed certificate from a certificate signing request.
|
// Sign creates a signed certificate from a certificate signing request.
|
||||||
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
|
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
|
||||||
var (
|
var (
|
||||||
errContext = context{"csr": csr, "signOptions": signOpts}
|
errContext = apiCtx{"csr": csr, "signOptions": signOpts}
|
||||||
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
|
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
|
||||||
certValidators = []provisioner.CertificateValidator{}
|
certValidators = []provisioner.CertificateValidator{}
|
||||||
issIdentity = a.intermediateIdentity
|
issIdentity = a.intermediateIdentity
|
||||||
|
@ -181,23 +181,23 @@ func (a *Authority) Renew(oldCert *x509.Certificate) (*x509.Certificate, *x509.C
|
||||||
leaf, err := x509util.NewLeafProfileWithTemplate(newCert,
|
leaf, err := x509util.NewLeafProfileWithTemplate(newCert,
|
||||||
issIdentity.Crt, issIdentity.Key)
|
issIdentity.Crt, issIdentity.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, &apiError{err, http.StatusInternalServerError, context{}}
|
return nil, nil, &apiError{err, http.StatusInternalServerError, apiCtx{}}
|
||||||
}
|
}
|
||||||
crtBytes, err := leaf.CreateCertificate()
|
crtBytes, err := leaf.CreateCertificate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"),
|
return nil, nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"),
|
||||||
http.StatusInternalServerError, context{}}
|
http.StatusInternalServerError, apiCtx{}}
|
||||||
}
|
}
|
||||||
|
|
||||||
serverCert, err := x509.ParseCertificate(crtBytes)
|
serverCert, err := x509.ParseCertificate(crtBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, &apiError{errors.Wrap(err, "error parsing new server certificate"),
|
return nil, nil, &apiError{errors.Wrap(err, "error parsing new server certificate"),
|
||||||
http.StatusInternalServerError, context{}}
|
http.StatusInternalServerError, apiCtx{}}
|
||||||
}
|
}
|
||||||
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
|
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"),
|
return nil, nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"),
|
||||||
http.StatusInternalServerError, context{}}
|
http.StatusInternalServerError, apiCtx{}}
|
||||||
}
|
}
|
||||||
|
|
||||||
return serverCert, caCert, nil
|
return serverCert, caCert, nil
|
||||||
|
@ -222,7 +222,7 @@ type RevokeOptions struct {
|
||||||
//
|
//
|
||||||
// TODO: Add OCSP and CRL support.
|
// TODO: Add OCSP and CRL support.
|
||||||
func (a *Authority) Revoke(opts *RevokeOptions) error {
|
func (a *Authority) Revoke(opts *RevokeOptions) error {
|
||||||
errContext := context{
|
errContext := apiCtx{
|
||||||
"serialNumber": opts.Serial,
|
"serialNumber": opts.Serial,
|
||||||
"reasonCode": opts.ReasonCode,
|
"reasonCode": opts.ReasonCode,
|
||||||
"reason": opts.Reason,
|
"reason": opts.Reason,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
@ -103,7 +104,8 @@ func TestSign(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
|
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
extraOpts, err := a.Authorize(token)
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||||
|
extraOpts, err := a.Authorize(ctx, token)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type signTest struct {
|
type signTest struct {
|
||||||
|
@ -124,7 +126,7 @@ func TestSign(t *testing.T) {
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: invalid certificate request"),
|
err: &apiError{errors.New("sign: invalid certificate request"),
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
context{"csr": csr, "signOptions": signOpts},
|
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -138,7 +140,7 @@ func TestSign(t *testing.T) {
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: invalid extra option type string"),
|
err: &apiError{errors.New("sign: invalid extra option type string"),
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
context{"csr": csr, "signOptions": signOpts},
|
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -153,7 +155,7 @@ func TestSign(t *testing.T) {
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
|
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
context{"csr": csr, "signOptions": signOpts},
|
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -168,7 +170,7 @@ func TestSign(t *testing.T) {
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: error creating new leaf certificate"),
|
err: &apiError{errors.New("sign: error creating new leaf certificate"),
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
context{"csr": csr, "signOptions": signOpts},
|
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -185,7 +187,7 @@ func TestSign(t *testing.T) {
|
||||||
signOpts: _signOpts,
|
signOpts: _signOpts,
|
||||||
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
|
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
|
||||||
http.StatusUnauthorized,
|
http.StatusUnauthorized,
|
||||||
context{"csr": csr, "signOptions": _signOpts},
|
apiCtx{"csr": csr, "signOptions": _signOpts},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -200,7 +202,7 @@ func TestSign(t *testing.T) {
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
||||||
http.StatusUnauthorized,
|
http.StatusUnauthorized,
|
||||||
context{"csr": csr, "signOptions": signOpts},
|
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -227,7 +229,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
|
err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
|
||||||
http.StatusUnauthorized,
|
http.StatusUnauthorized,
|
||||||
context{"csr": csr, "signOptions": signOpts},
|
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -238,7 +240,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
|
||||||
storeCertificate: func(crt *x509.Certificate) error {
|
storeCertificate: func(crt *x509.Certificate) error {
|
||||||
return &apiError{errors.New("force"),
|
return &apiError{errors.New("force"),
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
context{"csr": csr, "signOptions": signOpts}}
|
apiCtx{"csr": csr, "signOptions": signOpts}}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return &signTest{
|
return &signTest{
|
||||||
|
@ -248,7 +250,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: error storing certificate in db: force"),
|
err: &apiError{errors.New("sign: error storing certificate in db: force"),
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
context{"csr": csr, "signOptions": signOpts},
|
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -401,7 +403,7 @@ func TestRenew(t *testing.T) {
|
||||||
auth: _a,
|
auth: _a,
|
||||||
crt: crt,
|
crt: crt,
|
||||||
err: &apiError{errors.New("error renewing certificate from existing server certificate"),
|
err: &apiError{errors.New("error renewing certificate from existing server certificate"),
|
||||||
http.StatusInternalServerError, context{}},
|
http.StatusInternalServerError, apiCtx{}},
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
"fail-unauthorized": func() (*renewTest, error) {
|
"fail-unauthorized": func() (*renewTest, error) {
|
||||||
|
@ -596,7 +598,7 @@ func TestRevoke(t *testing.T) {
|
||||||
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
|
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
getCtx := func() map[string]interface{} {
|
getCtx := func() map[string]interface{} {
|
||||||
return context{
|
return apiCtx{
|
||||||
"serialNumber": "sn",
|
"serialNumber": "sn",
|
||||||
"reasonCode": reasonCode,
|
"reasonCode": reasonCode,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
|
|
22
ca/client.go
22
ca/client.go
|
@ -373,6 +373,28 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
|
||||||
return &sign, nil
|
return &sign, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignSSH performs the SSH certificate sign request to the CA and returns the
|
||||||
|
// api.SignSSHResponse struct.
|
||||||
|
func (c *Client) SignSSH(req *api.SignSSHRequest) (*api.SignSSHResponse, error) {
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
|
}
|
||||||
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign-ssh"})
|
||||||
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readError(resp.Body)
|
||||||
|
}
|
||||||
|
var sign api.SignSSHResponse
|
||||||
|
if err := readJSON(resp.Body, &sign); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||||
|
}
|
||||||
|
return &sign, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Renew performs the renew request to the CA and returns the api.SignResponse
|
// Renew performs the renew request to the CA and returns the api.SignResponse
|
||||||
// struct.
|
// struct.
|
||||||
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
||||||
|
|
|
@ -280,7 +280,11 @@ func stringifyFlag(f cli.Flag) string {
|
||||||
usage := fv.FieldByName("Usage").String()
|
usage := fv.FieldByName("Usage").String()
|
||||||
placeholder := placeholderString.FindString(usage)
|
placeholder := placeholderString.FindString(usage)
|
||||||
if placeholder == "" {
|
if placeholder == "" {
|
||||||
placeholder = "<value>"
|
switch f.(type) {
|
||||||
|
case cli.BoolFlag, cli.BoolTFlag:
|
||||||
|
default:
|
||||||
|
placeholder = "<value>"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage
|
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue