forked from TrueCloudLab/certificates
commit
7726f5ec75
41 changed files with 2710 additions and 120 deletions
8
Gopkg.lock
generated
8
Gopkg.lock
generated
|
@ -262,15 +262,20 @@
|
|||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:5dd7da6df07f42194cb25d162b4b89664ed7b08d7d4334f6a288393d54b095ce"
|
||||
digest = "1:afc49fe39c8c591fc2c8ddc73adc4c69e67125dde6c58e24c91b3b0cf78602be"
|
||||
name = "golang.org/x/crypto"
|
||||
packages = [
|
||||
"cryptobyte",
|
||||
"cryptobyte/asn1",
|
||||
"curve25519",
|
||||
"ed25519",
|
||||
"ed25519/internal/edwards25519",
|
||||
"internal/chacha20",
|
||||
"internal/subtle",
|
||||
"ocsp",
|
||||
"pbkdf2",
|
||||
"poly1305",
|
||||
"ssh",
|
||||
"ssh/terminal",
|
||||
]
|
||||
pruneopts = "UT"
|
||||
|
@ -394,6 +399,7 @@
|
|||
"github.com/urfave/cli",
|
||||
"golang.org/x/crypto/ed25519",
|
||||
"golang.org/x/crypto/ocsp",
|
||||
"golang.org/x/crypto/ssh",
|
||||
"golang.org/x/net/http2",
|
||||
"gopkg.in/square/go-jose.v2",
|
||||
"gopkg.in/square/go-jose.v2/jwt",
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/dsa"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
|
@ -26,9 +27,10 @@ import (
|
|||
|
||||
// Authority is the interface implemented by a CA authority.
|
||||
type Authority interface {
|
||||
SSHAuthority
|
||||
// NOTE: Authorize will be deprecated in future releases. Please use the
|
||||
// context specific Authoirize[Sign|Revoke|etc.] methods.
|
||||
Authorize(ott string) ([]provisioner.SignOption, error)
|
||||
// context specific Authorize[Sign|Revoke|etc.] methods.
|
||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||
GetTLSOptions() *tlsutil.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
|
@ -249,6 +251,8 @@ func (h *caHandler) Route(r Router) {
|
|||
r.MethodFunc("GET", "/federation", h.Federation)
|
||||
// For compatibility with old code:
|
||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||
// SSH CA
|
||||
r.MethodFunc("POST", "/sign-ssh", h.SignSSH)
|
||||
}
|
||||
|
||||
// Health is an HTTP handler that returns the status of the server.
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -424,7 +425,7 @@ type mockProvisioner struct {
|
|||
getEncryptedKey func() (string, string, bool)
|
||||
init func(provisioner.Config) error
|
||||
authorizeRevoke func(ott string) error
|
||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
||||
authorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
authorizeRenewal func(*x509.Certificate) error
|
||||
}
|
||||
|
||||
|
@ -480,9 +481,9 @@ func (m *mockProvisioner) AuthorizeRevoke(ott string) error {
|
|||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockProvisioner) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
func (m *mockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
if m.authorizeSign != nil {
|
||||
return m.authorizeSign(ott)
|
||||
return m.authorizeSign(ctx, ott)
|
||||
}
|
||||
return m.ret1.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
@ -501,6 +502,8 @@ type mockAuthority struct {
|
|||
getTLSOptions func() *tlsutil.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||
signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
|
@ -511,7 +514,7 @@ type mockAuthority struct {
|
|||
}
|
||||
|
||||
// TODO: remove once Authorize is deprecated.
|
||||
func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) {
|
||||
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return m.AuthorizeSign(ott)
|
||||
}
|
||||
|
||||
|
@ -543,6 +546,20 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio
|
|||
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
if m.signSSH != nil {
|
||||
return m.signSSH(key, opts, signOpts...)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if m.signSSHAddUser != nil {
|
||||
return m.signSSHAddUser(key, cert)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) {
|
||||
if m.renew != nil {
|
||||
return m.renew(cert)
|
||||
|
|
159
api/ssh.go
Normal file
159
api/ssh.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SSHAuthority is the interface implemented by a SSH CA authority.
|
||||
type SSHAuthority interface {
|
||||
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
}
|
||||
|
||||
// SignSSHRequest is the request body of an SSH certificate request.
|
||||
type SignSSHRequest struct {
|
||||
PublicKey []byte `json:"publicKey"` //base64 encoded
|
||||
OTT string `json:"ott"`
|
||||
CertType string `json:"certType,omitempty"`
|
||||
Principals []string `json:"principals,omitempty"`
|
||||
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||
AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"`
|
||||
}
|
||||
|
||||
// SignSSHResponse is the response object that returns the SSH certificate.
|
||||
type SignSSHResponse struct {
|
||||
Certificate SSHCertificate `json:"crt"`
|
||||
AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"`
|
||||
}
|
||||
|
||||
// SSHCertificate represents the response SSH certificate.
|
||||
type SSHCertificate struct {
|
||||
*ssh.Certificate `json:"omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
|
||||
// base64 encoded, openssh wire format version of the certificate.
|
||||
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
|
||||
if c.Certificate == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
s := base64.StdEncoding.EncodeToString(c.Certificate.Marshal())
|
||||
return []byte(`"` + s + `"`), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface. The certificate is
|
||||
// expected to be a quoted, base64 encoded, openssh wire formatted block of bytes.
|
||||
func (c *SSHCertificate) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrap(err, "error decoding certificate")
|
||||
}
|
||||
if s == "" {
|
||||
c.Certificate = nil
|
||||
return nil
|
||||
}
|
||||
certData, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error decoding ssh certificate")
|
||||
}
|
||||
pub, err := ssh.ParsePublicKey(certData)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing ssh certificate")
|
||||
}
|
||||
cert, ok := pub.(*ssh.Certificate)
|
||||
if !ok {
|
||||
return errors.Errorf("error decoding ssh certificate: %T is not an *ssh.Certificate", pub)
|
||||
}
|
||||
c.Certificate = cert
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates the SignSSHRequest.
|
||||
func (s *SignSSHRequest) Validate() error {
|
||||
switch {
|
||||
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
|
||||
return errors.Errorf("unknown certType %s", s.CertType)
|
||||
case len(s.PublicKey) == 0:
|
||||
return errors.New("missing or empty publicKey")
|
||||
case len(s.OTT) == 0:
|
||||
return errors.New("missing or empty ott")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignSSHRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, BadRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey")))
|
||||
return
|
||||
}
|
||||
|
||||
var addUserPublicKey ssh.PublicKey
|
||||
if body.AddUserPublicKey != nil {
|
||||
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||
if err != nil {
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
opts := provisioner.SSHOptions{
|
||||
CertType: body.CertType,
|
||||
Principals: body.Principals,
|
||||
ValidBefore: body.ValidBefore,
|
||||
ValidAfter: body.ValidAfter,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, Unauthorized(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
|
||||
var addUserCertificate *SSHCertificate
|
||||
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
|
||||
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
addUserCertificate = &SSHCertificate{addUserCert}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
JSON(w, &SignSSHResponse{
|
||||
Certificate: SSHCertificate{cert},
|
||||
AddUserCertificate: addUserCertificate,
|
||||
})
|
||||
}
|
327
api/ssh_test.go
Normal file
327
api/ssh_test.go
Normal file
|
@ -0,0 +1,327 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
sshSignerKey = mustKey()
|
||||
sshUserKey = mustKey()
|
||||
sshHostKey = mustKey()
|
||||
)
|
||||
|
||||
func mustKey() *ecdsa.PrivateKey {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return priv
|
||||
}
|
||||
|
||||
func signSSHCertificate(cert *ssh.Certificate) error {
|
||||
signerKey, err := ssh.NewPublicKey(sshSignerKey.Public())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
signer, err := ssh.NewSignerFromSigner(sshSignerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.SignatureKey = signerKey
|
||||
data := cert.Marshal()
|
||||
data = data[:len(data)-4]
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.Signature = sig
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSignedUserCertificate() (*ssh.Certificate, error) {
|
||||
key, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := time.Now()
|
||||
cert := &ssh.Certificate{
|
||||
Nonce: []byte("1234567890"),
|
||||
Key: key,
|
||||
Serial: 1234567890,
|
||||
CertType: ssh.UserCert,
|
||||
KeyId: "user@localhost",
|
||||
ValidPrincipals: []string{"user"},
|
||||
ValidAfter: uint64(t.Unix()),
|
||||
ValidBefore: uint64(t.Add(time.Hour).Unix()),
|
||||
Permissions: ssh.Permissions{
|
||||
CriticalOptions: map[string]string{},
|
||||
Extensions: map[string]string{
|
||||
"permit-X11-forwarding": "",
|
||||
"permit-agent-forwarding": "",
|
||||
"permit-port-forwarding": "",
|
||||
"permit-pty": "",
|
||||
"permit-user-rc": "",
|
||||
},
|
||||
},
|
||||
Reserved: []byte{},
|
||||
}
|
||||
if err := signSSHCertificate(cert); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func getSignedHostCertificate() (*ssh.Certificate, error) {
|
||||
key, err := ssh.NewPublicKey(sshHostKey.Public())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := time.Now()
|
||||
cert := &ssh.Certificate{
|
||||
Nonce: []byte("1234567890"),
|
||||
Key: key,
|
||||
Serial: 1234567890,
|
||||
CertType: ssh.UserCert,
|
||||
KeyId: "internal.smallstep.com",
|
||||
ValidPrincipals: []string{"internal.smallstep.com"},
|
||||
ValidAfter: uint64(t.Unix()),
|
||||
ValidBefore: uint64(t.Add(time.Hour).Unix()),
|
||||
Permissions: ssh.Permissions{
|
||||
CriticalOptions: map[string]string{},
|
||||
Extensions: map[string]string{},
|
||||
},
|
||||
Reserved: []byte{},
|
||||
}
|
||||
if err := signSSHCertificate(cert); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func TestSSHCertificate_MarshalJSON(t *testing.T) {
|
||||
user, err := getSignedUserCertificate()
|
||||
assert.FatalError(t, err)
|
||||
host, err := getSignedHostCertificate()
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||
|
||||
type fields struct {
|
||||
Certificate *ssh.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
||||
{"user", fields{Certificate: user}, []byte(`"` + userB64 + `"`), false},
|
||||
{"user", fields{Certificate: host}, []byte(`"` + hostB64 + `"`), false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := SSHCertificate{
|
||||
Certificate: tt.fields.Certificate,
|
||||
}
|
||||
got, err := c.MarshalJSON()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SSHCertificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SSHCertificate.MarshalJSON() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
|
||||
user, err := getSignedUserCertificate()
|
||||
assert.FatalError(t, err)
|
||||
host, err := getSignedHostCertificate()
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||
keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal())
|
||||
|
||||
type args struct {
|
||||
data []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *ssh.Certificate
|
||||
wantErr bool
|
||||
}{
|
||||
{"null", args{[]byte(`null`)}, nil, false},
|
||||
{"empty", args{[]byte(`""`)}, nil, false},
|
||||
{"user", args{[]byte(`"` + userB64 + `"`)}, user, false},
|
||||
{"host", args{[]byte(`"` + hostB64 + `"`)}, host, false},
|
||||
{"bad-string", args{[]byte(userB64)}, nil, true},
|
||||
{"bad-base64", args{[]byte(`"this-is-not-base64"`)}, nil, true},
|
||||
{"bad-key", args{[]byte(`"bm90LWEta2V5"`)}, nil, true},
|
||||
{"bat-cert", args{[]byte(`"` + keyB64 + `"`)}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &SSHCertificate{}
|
||||
if err := c.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SSHCertificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.want, c.Certificate) {
|
||||
t.Errorf("SSHCertificate.UnmarshalJSON() got = %v, want %v\n", c.Certificate, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignSSHRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
PublicKey []byte
|
||||
OTT string
|
||||
CertType string
|
||||
Principals []string
|
||||
ValidAfter TimeDuration
|
||||
ValidBefore TimeDuration
|
||||
AddUserPublicKey []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||
{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||
{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||
{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||
{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||
{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||
{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &SignSSHRequest{
|
||||
PublicKey: tt.fields.PublicKey,
|
||||
OTT: tt.fields.OTT,
|
||||
CertType: tt.fields.CertType,
|
||||
Principals: tt.fields.Principals,
|
||||
ValidAfter: tt.fields.ValidAfter,
|
||||
ValidBefore: tt.fields.ValidBefore,
|
||||
AddUserPublicKey: tt.fields.AddUserPublicKey,
|
||||
}
|
||||
if err := s.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SignSSH(t *testing.T) {
|
||||
user, err := getSignedUserCertificate()
|
||||
assert.FatalError(t, err)
|
||||
host, err := getSignedHostCertificate()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||
|
||||
userReq, err := json.Marshal(SignSSHRequest{
|
||||
PublicKey: user.Key.Marshal(),
|
||||
OTT: "ott",
|
||||
})
|
||||
assert.FatalError(t, err)
|
||||
hostReq, err := json.Marshal(SignSSHRequest{
|
||||
PublicKey: host.Key.Marshal(),
|
||||
OTT: "ott",
|
||||
})
|
||||
assert.FatalError(t, err)
|
||||
userAddReq, err := json.Marshal(SignSSHRequest{
|
||||
PublicKey: user.Key.Marshal(),
|
||||
OTT: "ott",
|
||||
AddUserPublicKey: user.Key.Marshal(),
|
||||
})
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
}
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
req []byte
|
||||
authErr error
|
||||
signCert *ssh.Certificate
|
||||
signErr error
|
||||
addUserCert *ssh.Certificate
|
||||
addUserErr error
|
||||
body []byte
|
||||
statusCode int
|
||||
}{
|
||||
{"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated},
|
||||
{"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated},
|
||||
{"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated},
|
||||
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized},
|
||||
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
|
||||
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return []provisioner.SignOption{}, tt.authErr
|
||||
},
|
||||
signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
return tt.signCert, tt.signErr
|
||||
},
|
||||
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
return tt.addUserCert, tt.addUserErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/sign-ssh", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SignSSH(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.Root unexpected error = %v", err)
|
||||
}
|
||||
if tt.statusCode < http.StatusBadRequest {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||
t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.body)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,12 +1,14 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
|
@ -20,6 +22,8 @@ type Authority struct {
|
|||
config *Config
|
||||
rootX509Certs []*x509.Certificate
|
||||
intermediateIdentity *x509util.Identity
|
||||
sshCAUserCertSignKey crypto.Signer
|
||||
sshCAHostCertSignKey crypto.Signer
|
||||
validateOnce bool
|
||||
certificates *sync.Map
|
||||
startTime time.Time
|
||||
|
@ -117,6 +121,22 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Decrypt and load SSH keys
|
||||
if a.config.SSH != nil {
|
||||
if a.config.SSH.HostKey != "" {
|
||||
a.sshCAHostCertSignKey, err = parseCryptoSigner(a.config.SSH.HostKey, a.config.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if a.config.SSH.UserKey != "" {
|
||||
a.sshCAUserCertSignKey, err = parseCryptoSigner(a.config.SSH.UserKey, a.config.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store all the provisioners
|
||||
for _, p := range a.config.AuthorityConfig.Provisioners {
|
||||
if err := a.provisioners.Store(p); err != nil {
|
||||
|
@ -143,3 +163,19 @@ func (a *Authority) GetDatabase() db.AuthDB {
|
|||
func (a *Authority) Shutdown() error {
|
||||
return a.db.Shutdown()
|
||||
}
|
||||
|
||||
func parseCryptoSigner(filename, password string) (crypto.Signer, error) {
|
||||
var opts []pemutil.Options
|
||||
if password != "" {
|
||||
opts = append(opts, pemutil.WithPassword([]byte(password)))
|
||||
}
|
||||
key, err := pemutil.Read(filename, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signer, ok := key.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("key %s of type %T cannot be used for signing operations", filename, key)
|
||||
}
|
||||
return signer, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
@ -72,33 +73,51 @@ func (a *Authority) authorizeToken(ott string) (provisioner.Interface, error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
// Authorize is a passthrough to AuthorizeSign.
|
||||
// NOTE: Authorize will be deprecated in a future release. Please use the
|
||||
// context specific Authorize[Sign|Revoke|etc.] going forwards.
|
||||
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
|
||||
return a.AuthorizeSign(ott)
|
||||
// Authorize grabs the method from the context and authorizes a signature
|
||||
// request by validating the one-time-token.
|
||||
func (a *Authority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
var errContext = apiCtx{"ott": ott}
|
||||
switch m := provisioner.MethodFromContext(ctx); m {
|
||||
case provisioner.SignMethod:
|
||||
return a.authorizeSign(ctx, ott)
|
||||
case provisioner.SignSSHMethod:
|
||||
if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{errors.New("authorize: ssh signing is not enabled"), http.StatusNotImplemented, errContext}
|
||||
}
|
||||
return a.authorizeSign(ctx, ott)
|
||||
case provisioner.RevokeMethod:
|
||||
return nil, &apiError{errors.New("authorize: revoke method is not supported"), http.StatusInternalServerError, errContext}
|
||||
default:
|
||||
return nil, &apiError{errors.Errorf("authorize: method %d is not supported", m), http.StatusInternalServerError, errContext}
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeSign authorizes a signature request by validating and authenticating
|
||||
// a OTT that must be sent w/ the request.
|
||||
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
var errContext = context{"ott": ott}
|
||||
|
||||
// authorizeSign loads the provisioner from the token, checks that it has not
|
||||
// been used again and calls the provisioner AuthorizeSign method. Returns a
|
||||
// list of methods to apply to the signing flow.
|
||||
func (a *Authority) authorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
var errContext = apiCtx{"ott": ott}
|
||||
p, err := a.authorizeToken(ott)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
|
||||
// Call the provisioner AuthorizeSign method to apply provisioner specific
|
||||
// auth claims and get the signing options.
|
||||
opts, err := p.AuthorizeSign(ott)
|
||||
opts, err := p.AuthorizeSign(ctx, ott)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorizeSign"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// AuthorizeSign authorizes a signature request by validating and authenticating
|
||||
// a OTT that must be sent w/ the request.
|
||||
//
|
||||
// NOTE: This method is deprecated and should not be used. We make it available
|
||||
// in the short term os as not to break existing clients.
|
||||
func (a *Authority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
return a.Authorize(ctx, ott)
|
||||
}
|
||||
|
||||
// authorizeRevoke authorizes a revocation request by validating and authenticating
|
||||
// the RevokeOptions POSTed with the request.
|
||||
// Returns a tuple of the provisioner ID and error, if one occurred.
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
|
@ -72,7 +75,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
auth: a,
|
||||
ott: "foo",
|
||||
err: &apiError{errors.New("authorizeToken: error parsing token"),
|
||||
http.StatusUnauthorized, context{"ott": "foo"}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
|
||||
}
|
||||
},
|
||||
"fail/prehistoric-token": func(t *testing.T) *authorizeTest {
|
||||
|
@ -91,7 +94,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorizeToken: token issued before the bootstrap of certificate authority"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"fail/provisioner-not-found": func(t *testing.T) *authorizeTest {
|
||||
|
@ -113,7 +116,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorizeToken: provisioner not found or invalid audience (https://test.ca.smallstep.com/revoke)"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"ok/simpledb": func(t *testing.T) *authorizeTest {
|
||||
|
@ -150,7 +153,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
auth: _a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorizeToken: token already used"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"ok/mockNoSQLDB": func(t *testing.T) *authorizeTest {
|
||||
|
@ -198,7 +201,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
auth: _a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorizeToken: failed when checking if token already used: force"),
|
||||
http.StatusInternalServerError, context{"ott": raw}},
|
||||
http.StatusInternalServerError, apiCtx{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest {
|
||||
|
@ -223,7 +226,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
auth: _a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorizeToken: token already used"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -388,7 +391,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
|
|||
auth: a,
|
||||
ott: "foo",
|
||||
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
|
||||
http.StatusUnauthorized, context{"ott": "foo"}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
|
||||
}
|
||||
},
|
||||
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
|
||||
|
@ -406,7 +409,7 @@ func TestAuthority_AuthorizeSign(t *testing.T) {
|
|||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *authorizeTest {
|
||||
|
@ -480,7 +483,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
ott: "foo",
|
||||
err: &apiError{errors.New("authorizeSign: authorizeToken: error parsing token"),
|
||||
http.StatusUnauthorized, context{"ott": "foo"}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": "foo"}},
|
||||
}
|
||||
},
|
||||
"fail/invalid-subject": func(t *testing.T) *authorizeTest {
|
||||
|
@ -498,7 +501,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorizeSign: token subject cannot be empty"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
http.StatusUnauthorized, apiCtx{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *authorizeTest {
|
||||
|
@ -522,8 +525,8 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
for name, genTestCase := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := genTestCase(t)
|
||||
|
||||
got, err := tc.auth.Authorize(tc.ott)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
got, err := tc.auth.Authorize(ctx, tc.ott)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Nil(t, got)
|
||||
|
@ -573,7 +576,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
|
|||
auth: a,
|
||||
crt: fooCrt,
|
||||
err: &apiError{errors.New("renew: force"),
|
||||
http.StatusInternalServerError, context{"serialNumber": "102012593071130646873265215610956555026"}},
|
||||
http.StatusInternalServerError, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
|
||||
}
|
||||
},
|
||||
"fail/revoked": func(t *testing.T) *authorizeTest {
|
||||
|
@ -587,7 +590,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
|
|||
auth: a,
|
||||
crt: fooCrt,
|
||||
err: &apiError{errors.New("renew: certificate has been revoked"),
|
||||
http.StatusUnauthorized, context{"serialNumber": "102012593071130646873265215610956555026"}},
|
||||
http.StatusUnauthorized, apiCtx{"serialNumber": "102012593071130646873265215610956555026"}},
|
||||
}
|
||||
},
|
||||
"fail/load-provisioner": func(t *testing.T) *authorizeTest {
|
||||
|
@ -601,7 +604,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
|
|||
auth: a,
|
||||
crt: otherCrt,
|
||||
err: &apiError{errors.New("renew: provisioner not found"),
|
||||
http.StatusUnauthorized, context{"serialNumber": "41633491264736369593451462439668497527"}},
|
||||
http.StatusUnauthorized, apiCtx{"serialNumber": "41633491264736369593451462439668497527"}},
|
||||
}
|
||||
},
|
||||
"fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest {
|
||||
|
@ -616,7 +619,7 @@ func TestAuthority_authorizeRenewal(t *testing.T) {
|
|||
auth: a,
|
||||
crt: renewDisabledCrt,
|
||||
err: &apiError{errors.New("renew: renew is disabled for provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
|
||||
http.StatusUnauthorized, context{"serialNumber": "119772236532068856521070735128919532568"}},
|
||||
http.StatusUnauthorized, apiCtx{"serialNumber": "119772236532068856521070735128919532568"}},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *authorizeTest {
|
||||
|
|
|
@ -28,11 +28,19 @@ var (
|
|||
Renegotiation: false,
|
||||
}
|
||||
defaultDisableRenewal = false
|
||||
defaultEnableSSHCA = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 4 * time.Hour},
|
||||
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &defaultEnableSSHCA,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -44,6 +52,7 @@ type Config struct {
|
|||
IntermediateKey string `json:"key"`
|
||||
Address string `json:"address"`
|
||||
DNSNames []string `json:"dnsNames"`
|
||||
SSH *SSHConfig `json:"ssh,omitempty"`
|
||||
Logger json.RawMessage `json:"logger,omitempty"`
|
||||
DB *db.Config `json:"db,omitempty"`
|
||||
Monitoring json.RawMessage `json:"monitoring,omitempty"`
|
||||
|
@ -92,6 +101,14 @@ func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SSHConfig contains the user and host keys.
|
||||
type SSHConfig struct {
|
||||
HostKey string `json:"hostKey"`
|
||||
UserKey string `json:"userKey"`
|
||||
AddUserPrincipal string `json:"addUserPrincipal"`
|
||||
AddUserCommand string `json:"addUserCommand"`
|
||||
}
|
||||
|
||||
// LoadConfiguration parses the given filename in JSON format and returns the
|
||||
// configuration struct.
|
||||
func LoadConfiguration(filename string) (*Config, error) {
|
||||
|
|
|
@ -4,13 +4,13 @@ import (
|
|||
"net/http"
|
||||
)
|
||||
|
||||
type context map[string]interface{}
|
||||
type apiCtx map[string]interface{}
|
||||
|
||||
// Error implements the api.Error interface and adds context to error messages.
|
||||
type apiError struct {
|
||||
err error
|
||||
code int
|
||||
context context
|
||||
context apiCtx
|
||||
}
|
||||
|
||||
// Cause implements the errors.Causer interface and returns the original error.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
|
@ -266,13 +267,21 @@ func (p *AWS) Init(config Config) (err error) {
|
|||
|
||||
// AuthorizeSign validates the given token and returns the sign options that
|
||||
// will be used on certificate creation.
|
||||
func (p *AWS) AuthorizeSign(token string) ([]SignOption, error) {
|
||||
func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
payload, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
doc := payload.document
|
||||
|
||||
// Check for the sign ssh method, default to sign X.509
|
||||
if m := MethodFromContext(ctx); m == SignSSHMethod {
|
||||
if p.claimer.IsSSHCAEnabled() == false {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
return p.authorizeSSHSign(payload)
|
||||
}
|
||||
|
||||
doc := payload.document
|
||||
// Enforce known CN and default DNS and IP if configured.
|
||||
// By default we'll accept the CN and SANs in the CSR.
|
||||
// There's no way to trust them other than TOFU.
|
||||
|
@ -433,3 +442,35 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
payload.document = doc
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *AWS) authorizeSSHSign(claims *awsPayload) ([]SignOption, error) {
|
||||
doc := claims.document
|
||||
|
||||
signOptions := []SignOption{
|
||||
// set the key id to the token subject
|
||||
sshCertificateKeyIDModifier(claims.Subject),
|
||||
}
|
||||
|
||||
// Default to host + known IPs/hostnames
|
||||
defaults := SSHOptions{
|
||||
CertType: SSHHostCert,
|
||||
Principals: []string{
|
||||
doc.PrivateIP,
|
||||
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region),
|
||||
},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||
|
||||
return append(signOptions,
|
||||
// set the default extensions
|
||||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{p.claimer},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
|
@ -347,7 +349,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.aws.AuthorizeSign(tt.args.token)
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -357,6 +360,84 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAWS_AuthorizeSign_SSH(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
|
||||
p1, srv, err := generateAWSWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
key, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
expectedHostOptionsIP := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"127.0.0.1"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
expectedHostOptionsHostname := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
aws *AWS
|
||||
args args
|
||||
expected *SSHOptions
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false},
|
||||
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}}, expectedHostOptionsIP, false, false},
|
||||
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptionsHostname, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}}, nil, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
if tt.wantSignErr {
|
||||
assert.Nil(t, cert)
|
||||
} else {
|
||||
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestAWS_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
|
@ -209,7 +210,7 @@ func (p *Azure) Init(config Config) (err error) {
|
|||
|
||||
// AuthorizeSign validates the given token and returns the sign options that
|
||||
// will be used on certificate creation.
|
||||
func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
||||
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
|
@ -264,6 +265,14 @@ func (p *Azure) AuthorizeSign(token string) ([]SignOption, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Check for the sign ssh method, default to sign X.509
|
||||
if m := MethodFromContext(ctx); m == SignSSHMethod {
|
||||
if p.claimer.IsSSHCAEnabled() == false {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
return p.authorizeSSHSign(claims, name)
|
||||
}
|
||||
|
||||
// Enforce known common name and default DNS if configured.
|
||||
// By default we'll accept the CN and SANs in the CSR.
|
||||
// There's no way to trust them other than TOFU.
|
||||
|
@ -296,6 +305,33 @@ func (p *Azure) AuthorizeRevoke(token string) error {
|
|||
return errors.New("revoke is not supported on a Azure provisioner")
|
||||
}
|
||||
|
||||
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *Azure) authorizeSSHSign(claims azurePayload, name string) ([]SignOption, error) {
|
||||
signOptions := []SignOption{
|
||||
// set the key id to the token subject
|
||||
sshCertificateKeyIDModifier(name),
|
||||
}
|
||||
|
||||
// Default to host + known hostnames
|
||||
defaults := SSHOptions{
|
||||
CertType: SSHHostCert,
|
||||
Principals: []string{name},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||
|
||||
return append(signOptions,
|
||||
// set the default extensions
|
||||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{p.claimer},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
// assertConfig initializes the config if it has not been initialized
|
||||
func (p *Azure) assertConfig() {
|
||||
if p.config == nil {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
|
@ -295,7 +297,8 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.azure.AuthorizeSign(tt.args.token)
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -305,6 +308,75 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeSign_SSH(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
|
||||
p1, srv, err := generateAzureWithServer()
|
||||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
key, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"virtualMachine"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
azure *Azure
|
||||
args args
|
||||
expected *SSHOptions
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}}, nil, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
if tt.wantSignErr {
|
||||
assert.Nil(t, cert)
|
||||
} else {
|
||||
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzure_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
|
|
|
@ -8,10 +8,19 @@ import (
|
|||
|
||||
// Claims so that individual provisioners can override global claims.
|
||||
type Claims struct {
|
||||
// TLS CA properties
|
||||
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||
// SSH CA properties
|
||||
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
|
||||
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
|
||||
DefaultUserSSHDur *Duration `json:"defaultUserSSHCertDuration,omitempty"`
|
||||
MinHostSSHDur *Duration `json:"minHostSSHCertDuration,omitempty"`
|
||||
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
|
||||
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
|
||||
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
|
||||
}
|
||||
|
||||
// Claimer is the type that controls claims. It provides an interface around the
|
||||
|
@ -30,11 +39,19 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
|
|||
// Claims returns the merge of the inner and global claims.
|
||||
func (c *Claimer) Claims() Claims {
|
||||
disableRenewal := c.IsDisableRenewal()
|
||||
enableSSHCA := c.IsSSHCAEnabled()
|
||||
return Claims{
|
||||
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
||||
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
||||
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
||||
DisableRenewal: &disableRenewal,
|
||||
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
||||
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
||||
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
||||
DisableRenewal: &disableRenewal,
|
||||
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
|
||||
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
|
||||
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
|
||||
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
|
||||
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
|
||||
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
|
||||
EnableSSHCA: &enableSSHCA,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,6 +95,76 @@ func (c *Claimer) IsDisableRenewal() bool {
|
|||
return *c.claims.DisableRenewal
|
||||
}
|
||||
|
||||
// DefaultUserSSHCertDuration returns the default SSH user cert duration for the
|
||||
// provisioner. If the default is not set within the provisioner, then the
|
||||
// global default from the authority configuration will be used.
|
||||
func (c *Claimer) DefaultUserSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.DefaultUserSSHDur == nil {
|
||||
return c.global.DefaultUserSSHDur.Duration
|
||||
}
|
||||
return c.claims.DefaultUserSSHDur.Duration
|
||||
}
|
||||
|
||||
// MinUserSSHCertDuration returns the minimum SSH user cert duration for the
|
||||
// provisioner. If the minimum is not set within the provisioner, then the
|
||||
// global minimum from the authority configuration will be used.
|
||||
func (c *Claimer) MinUserSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MinUserSSHDur == nil {
|
||||
return c.global.MinUserSSHDur.Duration
|
||||
}
|
||||
return c.claims.MinUserSSHDur.Duration
|
||||
}
|
||||
|
||||
// MaxUserSSHCertDuration returns the maximum SSH user cert duration for the
|
||||
// provisioner. If the maximum is not set within the provisioner, then the
|
||||
// global maximum from the authority configuration will be used.
|
||||
func (c *Claimer) MaxUserSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MaxUserSSHDur == nil {
|
||||
return c.global.MaxUserSSHDur.Duration
|
||||
}
|
||||
return c.claims.MaxUserSSHDur.Duration
|
||||
}
|
||||
|
||||
// DefaultHostSSHCertDuration returns the default SSH host cert duration for the
|
||||
// provisioner. If the default is not set within the provisioner, then the
|
||||
// global default from the authority configuration will be used.
|
||||
func (c *Claimer) DefaultHostSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.DefaultHostSSHDur == nil {
|
||||
return c.global.DefaultHostSSHDur.Duration
|
||||
}
|
||||
return c.claims.DefaultHostSSHDur.Duration
|
||||
}
|
||||
|
||||
// MinHostSSHCertDuration returns the minimum SSH host cert duration for the
|
||||
// provisioner. If the minimum is not set within the provisioner, then the
|
||||
// global minimum from the authority configuration will be used.
|
||||
func (c *Claimer) MinHostSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MinHostSSHDur == nil {
|
||||
return c.global.MinHostSSHDur.Duration
|
||||
}
|
||||
return c.claims.MinHostSSHDur.Duration
|
||||
}
|
||||
|
||||
// MaxHostSSHCertDuration returns the maximum SSH Host cert duration for the
|
||||
// provisioner. If the maximum is not set within the provisioner, then the
|
||||
// global maximum from the authority configuration will be used.
|
||||
func (c *Claimer) MaxHostSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MaxHostSSHDur == nil {
|
||||
return c.global.MaxHostSSHDur.Duration
|
||||
}
|
||||
return c.claims.MaxHostSSHDur.Duration
|
||||
}
|
||||
|
||||
// IsSSHCAEnabled returns if the SSH CA is enabled for the provisioner. If the
|
||||
// property is not set within the provisioner, then the global value from the
|
||||
// authority configuration will be used.
|
||||
func (c *Claimer) IsSSHCAEnabled() bool {
|
||||
if c.claims == nil || c.claims.EnableSSHCA == nil {
|
||||
return *c.global.EnableSSHCA
|
||||
}
|
||||
return *c.claims.EnableSSHCA
|
||||
}
|
||||
|
||||
// Validate validates and modifies the Claims with default values.
|
||||
func (c *Claimer) Validate() error {
|
||||
var (
|
||||
|
|
|
@ -2,6 +2,7 @@ package provisioner
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
|
@ -205,13 +206,21 @@ func (p *GCP) Init(config Config) error {
|
|||
|
||||
// AuthorizeSign validates the given token and returns the sign options that
|
||||
// will be used on certificate creation.
|
||||
func (p *GCP) AuthorizeSign(token string) ([]SignOption, error) {
|
||||
func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ce := claims.Google.ComputeEngine
|
||||
|
||||
// Check for the sign ssh method, default to sign X.509
|
||||
if m := MethodFromContext(ctx); m == SignSSHMethod {
|
||||
if p.claimer.IsSSHCAEnabled() == false {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
return p.authorizeSSHSign(claims)
|
||||
}
|
||||
|
||||
ce := claims.Google.ComputeEngine
|
||||
// Enforce known common name and default DNS if configured.
|
||||
// By default we we'll accept the CN and SANs in the CSR.
|
||||
// There's no way to trust them other than TOFU.
|
||||
|
@ -345,3 +354,35 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *GCP) authorizeSSHSign(claims *gcpPayload) ([]SignOption, error) {
|
||||
ce := claims.Google.ComputeEngine
|
||||
|
||||
signOptions := []SignOption{
|
||||
// set the key id to the token subject
|
||||
sshCertificateKeyIDModifier(ce.InstanceName),
|
||||
}
|
||||
|
||||
// Default to host + known hostnames
|
||||
defaults := SSHOptions{
|
||||
CertType: SSHHostCert,
|
||||
Principals: []string{
|
||||
fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID),
|
||||
fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID),
|
||||
},
|
||||
}
|
||||
// Validate user options
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
// Set defaults if not given as user options
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||
|
||||
return append(signOptions,
|
||||
// set the default extensions
|
||||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{p.claimer},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
|
@ -330,7 +332,8 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.gcp.AuthorizeSign(tt.args.token)
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -340,6 +343,87 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGCP_AuthorizeSign_SSH(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
|
||||
p1, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := generateGCPToken(p1.ServiceAccounts[0],
|
||||
"https://accounts.google.com", p1.GetID(),
|
||||
"instance-id", "instance-name", "project-id", "zone",
|
||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
key, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
expectedHostOptionsPrincipal1 := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"instance-name.c.project-id.internal"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
expectedHostOptionsPrincipal2 := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"instance-name.zone.c.project-id.internal"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
gcp *GCP
|
||||
args args
|
||||
expected *SSHOptions
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false},
|
||||
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}}, expectedHostOptionsPrincipal1, false, false},
|
||||
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}}, expectedHostOptionsPrincipal2, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}}, nil, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
if tt.wantSignErr {
|
||||
assert.Nil(t, cert)
|
||||
} else {
|
||||
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGCP_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"time"
|
||||
|
||||
|
@ -12,7 +13,12 @@ import (
|
|||
// jwtPayload extends jwt.Claims with step attributes.
|
||||
type jwtPayload struct {
|
||||
jose.Claims
|
||||
SANs []string `json:"sans,omitempty"`
|
||||
SANs []string `json:"sans,omitempty"`
|
||||
Step *stepPayload `json:"step,omitempty"`
|
||||
}
|
||||
|
||||
type stepPayload struct {
|
||||
SSH *SSHOptions `json:"ssh,omitempty"`
|
||||
}
|
||||
|
||||
// JWK is the default provisioner, an entity that can sign tokens necessary for
|
||||
|
@ -129,11 +135,20 @@ func (p *JWK) AuthorizeRevoke(token string) error {
|
|||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *JWK) AuthorizeSign(token string) ([]SignOption, error) {
|
||||
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for SSH token
|
||||
if claims.Step != nil && claims.Step.SSH != nil {
|
||||
if p.claimer.IsSSHCAEnabled() == false {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
return p.authorizeSSHSign(claims)
|
||||
}
|
||||
|
||||
// NOTE: This is for backwards compatibility with older versions of cli
|
||||
// and certificates. Older versions added the token subject as the only SAN
|
||||
// in a CSR by default.
|
||||
|
@ -161,3 +176,41 @@ func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *JWK) authorizeSSHSign(claims *jwtPayload) ([]SignOption, error) {
|
||||
t := now()
|
||||
opts := claims.Step.SSH
|
||||
signOptions := []SignOption{
|
||||
// validates user's SSHOptions with the ones in the token
|
||||
sshCertificateOptionsValidator(*opts),
|
||||
// set the key id to the token subject
|
||||
sshCertificateKeyIDModifier(claims.Subject),
|
||||
}
|
||||
|
||||
// Add modifiers from custom claims
|
||||
if opts.CertType != "" {
|
||||
signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType))
|
||||
}
|
||||
if len(opts.Principals) > 0 {
|
||||
signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals))
|
||||
}
|
||||
if !opts.ValidAfter.IsZero() {
|
||||
signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix()))
|
||||
}
|
||||
if !opts.ValidBefore.IsZero() {
|
||||
signOptions = append(signOptions, sshCertificateValidBeforeModifier(opts.ValidBefore.RelativeTime(t).Unix()))
|
||||
}
|
||||
|
||||
// Default to a user certificate with no principals if not set
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier{CertType: SSHUserCert})
|
||||
|
||||
return append(signOptions,
|
||||
// set the default extensions
|
||||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{p.claimer},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"strings"
|
||||
|
@ -13,11 +15,19 @@ import (
|
|||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
defaultEnableSSHCA = true
|
||||
globalProvisionerClaims = Claims{
|
||||
MinTLSDur: &Duration{5 * time.Minute},
|
||||
MaxTLSDur: &Duration{24 * time.Hour},
|
||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
MinTLSDur: &Duration{5 * time.Minute},
|
||||
MaxTLSDur: &Duration{24 * time.Hour},
|
||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &Duration{Duration: 4 * time.Hour},
|
||||
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &defaultEnableSSHCA,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -259,7 +269,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got, err := tt.prov.AuthorizeSign(tt.args.token); err != nil {
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
|
@ -318,3 +329,201 @@ func TestJWK_AuthorizeRenewal(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_AuthorizeSign_SSH(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
iss, aud := p1.Name, testAudiences.Sign[0]
|
||||
|
||||
t1, err := generateSimpleSSHUserToken(iss, aud, jwk)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t2, err := generateSimpleSSHHostToken(iss, aud, jwk)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// invalid signature
|
||||
failSig := t1[0 : len(t1)-2]
|
||||
|
||||
key, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedUserOptions := &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||
}
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"smallstep.com"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
expected *SSHOptions
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"user", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false},
|
||||
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false},
|
||||
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"host", p1, args{t2, SSHOptions{}}, expectedHostOptions, false, false},
|
||||
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||
{"fail-signature", p1, args{failSig, SSHOptions{}}, nil, true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
if tt.wantSignErr {
|
||||
assert.Nil(t, cert)
|
||||
} else {
|
||||
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
sub, iss, aud, iat := "subject@smallstep.com", p1.Name, testAudiences.Sign[0], time.Now()
|
||||
|
||||
key, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedUserOptions := &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||
}
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"smallstep.com"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
type args struct {
|
||||
sub, iss, aud string
|
||||
iat time.Time
|
||||
tokSSHOpts *SSHOptions
|
||||
userSSHOpts *SSHOptions
|
||||
jwk *jose.JSONWebKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
expected *SSHOptions
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok-user", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, expectedUserOptions, false, false},
|
||||
{"ok-host", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, &SSHOptions{}, jwk}, expectedHostOptions, false, false},
|
||||
{"ok-user-opts", p1, args{sub, iss, aud, iat, &SSHOptions{}, &SSHOptions{CertType: "user", Principals: []string{"name"}}, jwk}, expectedUserOptions, false, false},
|
||||
{"ok-host-opts", p1, args{sub, iss, aud, iat, &SSHOptions{}, &SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, jwk}, expectedHostOptions, false, false},
|
||||
{"ok-user-mixed", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user"}, &SSHOptions{Principals: []string{"name"}}, jwk}, expectedUserOptions, false, false},
|
||||
{"ok-host-mixed", p1, args{sub, iss, aud, iat, &SSHOptions{Principals: []string{"smallstep.com"}}, &SSHOptions{CertType: "host"}, jwk}, expectedHostOptions, false, false},
|
||||
{"ok-user-validAfter", p1, args{sub, iss, aud, iat, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"},
|
||||
}, &SSHOptions{
|
||||
ValidAfter: NewTimeDuration(tm.Add(-time.Hour)),
|
||||
}, jwk}, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(-time.Hour)), ValidBefore: NewTimeDuration(tm.Add(userDuration - time.Hour)),
|
||||
}, false, false},
|
||||
{"ok-user-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"},
|
||||
}, &SSHOptions{
|
||||
ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||
}, jwk}, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||
}, false, false},
|
||||
{"ok-user-validAfter-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"},
|
||||
}, &SSHOptions{
|
||||
ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||
}, jwk}, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||
}, false, false},
|
||||
{"ok-user-match", p1, args{sub, iss, aud, iat, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)),
|
||||
}, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)),
|
||||
}, jwk}, &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)),
|
||||
}, false, false},
|
||||
{"fail-certType", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{CertType: "host"}, jwk}, nil, false, true},
|
||||
{"fail-principals", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{Principals: []string{"root"}}, jwk}, nil, false, true},
|
||||
{"fail-validAfter", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm)}, &SSHOptions{ValidAfter: NewTimeDuration(tm.Add(time.Hour))}, jwk}, nil, false, true},
|
||||
{"fail-validBefore", p1, args{sub, iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}, ValidBefore: NewTimeDuration(tm.Add(time.Hour))}, &SSHOptions{ValidBefore: NewTimeDuration(tm.Add(10 * time.Hour))}, jwk}, nil, false, true},
|
||||
{"fail-subject", p1, args{"", iss, aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||
{"fail-issuer", p1, args{sub, "invalid", aud, iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||
{"fail-audience", p1, args{sub, iss, "invalid", iat, &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||
{"fail-expired", p1, args{sub, iss, aud, iat.Add(-6 * time.Minute), &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||
{"fail-notBefore", p1, args{sub, iss, aud, iat.Add(5 * time.Minute), &SSHOptions{CertType: "user", Principals: []string{"name"}}, &SSHOptions{}, jwk}, nil, true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk)
|
||||
assert.FatalError(t, err)
|
||||
if got, err := tt.prov.AuthorizeSign(ctx, token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
} else if !tt.wantErr && assert.NotNil(t, got) {
|
||||
var opts SSHOptions
|
||||
if tt.args.userSSHOpts != nil {
|
||||
opts = *tt.args.userSSHOpts
|
||||
}
|
||||
cert, err := signSSHCertificate(key.Public().Key, opts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
if tt.wantSignErr {
|
||||
assert.Nil(t, cert)
|
||||
} else {
|
||||
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
34
authority/provisioner/method.go
Normal file
34
authority/provisioner/method.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Method indicates the action to action that we will perform, it's used as part
|
||||
// of the context in the call to authorize. It defaults to Sing.
|
||||
type Method int
|
||||
|
||||
// The key to save the Method in the context.
|
||||
type methodKey struct{}
|
||||
|
||||
const (
|
||||
// SignMethod is the method used to sign X.509 certificates.
|
||||
SignMethod Method = iota
|
||||
// SignSSHMethod is the method used to sign SSH certificate.
|
||||
SignSSHMethod
|
||||
// RevokeMethod is the method used to revoke X.509 certificates.
|
||||
RevokeMethod
|
||||
)
|
||||
|
||||
// NewContextWithMethod creates a new context from ctx and attaches method to
|
||||
// it.
|
||||
func NewContextWithMethod(ctx context.Context, method Method) context.Context {
|
||||
return context.WithValue(ctx, methodKey{}, method)
|
||||
}
|
||||
|
||||
// MethodFromContext returns the Method saved in ctx. Returns Sign if the given
|
||||
// context has no Method associated with it.
|
||||
func MethodFromContext(ctx context.Context) Method {
|
||||
m, _ := ctx.Value(methodKey{}).(Method)
|
||||
return m
|
||||
}
|
|
@ -1,6 +1,9 @@
|
|||
package provisioner
|
||||
|
||||
import "crypto/x509"
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
// noop provisioners is a provisioner that accepts anything.
|
||||
type noop struct{}
|
||||
|
@ -28,7 +31,7 @@ func (p *noop) Init(config Config) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *noop) AuthorizeSign(token string) ([]SignOption, error) {
|
||||
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
return []SignOption{}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
|
||||
|
@ -21,7 +22,8 @@ func Test_noop(t *testing.T) {
|
|||
assert.Equals(t, "", key)
|
||||
assert.Equals(t, false, ok)
|
||||
|
||||
sigOptions, err := p.AuthorizeSign("foo")
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
sigOptions, err := p.AuthorizeSign(ctx, "foo")
|
||||
assert.Equals(t, []SignOption{}, sigOptions)
|
||||
assert.Equals(t, nil, err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
@ -259,12 +260,29 @@ func (o *OIDC) AuthorizeRevoke(token string) error {
|
|||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (o *OIDC) AuthorizeSign(token string) ([]SignOption, error) {
|
||||
func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := o.authorizeToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for the sign ssh method, default to sign X.509
|
||||
if m := MethodFromContext(ctx); m == SignSSHMethod {
|
||||
if o.claimer.IsSSHCAEnabled() == false {
|
||||
return nil, errors.Errorf("ssh ca is disabled for provisioner %s", o.GetID())
|
||||
}
|
||||
return o.authorizeSSHSign(claims)
|
||||
}
|
||||
|
||||
// Admins should be able to authorize any SAN
|
||||
if o.IsAdmin(claims.Email) {
|
||||
return []SignOption{
|
||||
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
so := []SignOption{
|
||||
defaultPublicKeyValidator{},
|
||||
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
|
||||
|
@ -287,6 +305,42 @@ func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// authorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (o *OIDC) authorizeSSHSign(claims *openIDPayload) ([]SignOption, error) {
|
||||
signOptions := []SignOption{
|
||||
// set the key id to the token subject
|
||||
sshCertificateKeyIDModifier(claims.Email),
|
||||
}
|
||||
|
||||
name := SanitizeSSHUserPrincipal(claims.Email)
|
||||
if !sshUserRegex.MatchString(name) {
|
||||
return nil, errors.Errorf("invalid principal '%s' from email address '%s'", name, claims.Email)
|
||||
}
|
||||
|
||||
// Admin users will default to user + name but they can be changed by the
|
||||
// user options. Non-admins are only able to sign user certificates.
|
||||
defaults := SSHOptions{
|
||||
CertType: SSHUserCert,
|
||||
Principals: []string{name},
|
||||
}
|
||||
|
||||
if !o.IsAdmin(claims.Email) {
|
||||
signOptions = append(signOptions, sshCertificateOptionsValidator(defaults))
|
||||
}
|
||||
|
||||
// Default to a user with name as principal if not set
|
||||
signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults))
|
||||
|
||||
return append(signOptions,
|
||||
// set the default extensions
|
||||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{o.claimer},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
}
|
||||
|
||||
func getAndDecode(uri string, v interface{}) error {
|
||||
resp, err := http.Get(uri)
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -276,7 +278,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.prov.AuthorizeSign(tt.args.token)
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -286,7 +289,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
|||
} else {
|
||||
assert.NotNil(t, got)
|
||||
if tt.name == "admin" {
|
||||
assert.Len(t, 4, got)
|
||||
assert.Len(t, 3, got)
|
||||
} else {
|
||||
assert.Len(t, 5, got)
|
||||
}
|
||||
|
@ -295,6 +298,117 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
var keys jose.JSONWebKeySet
|
||||
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||
|
||||
// Create test provisioners
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
// Admin + Domains
|
||||
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
|
||||
p3.Domains = []string{"smallstep.com"}
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
config := Config{Claims: globalProvisionerClaims}
|
||||
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
assert.FatalError(t, p1.Init(config))
|
||||
assert.FatalError(t, p2.Init(config))
|
||||
assert.FatalError(t, p3.Init(config))
|
||||
|
||||
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// Admin email not in domains
|
||||
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// Invalid email
|
||||
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
key, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedUserOptions := &SSHOptions{
|
||||
CertType: "user", Principals: []string{"name"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||
}
|
||||
expectedAdminOptions := &SSHOptions{
|
||||
CertType: "user", Principals: []string{"root"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||
}
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"smallstep.com"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
}
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
expected *SSHOptions
|
||||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false},
|
||||
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"admin", p3, args{okAdmin, SSHOptions{}}, expectedAdminOptions, false, false},
|
||||
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}}, expectedAdminOptions, false, false},
|
||||
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}}, expectedAdminOptions, false, false},
|
||||
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}}, nil, false, true},
|
||||
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}}, nil, false, true},
|
||||
{"fail-email", p3, args{failEmail, SSHOptions{}}, nil, true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignSSHMethod)
|
||||
got, err := tt.prov.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
if tt.wantSignErr {
|
||||
assert.Nil(t, cert)
|
||||
} else {
|
||||
assert.NoError(t, validateSSHCertificate(cert, tt.expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -17,7 +19,7 @@ type Interface interface {
|
|||
GetType() Type
|
||||
GetEncryptedKey() (kid string, key string, ok bool)
|
||||
Init(config Config) error
|
||||
AuthorizeSign(token string) ([]SignOption, error)
|
||||
AuthorizeSign(ctx context.Context, token string) ([]SignOption, error)
|
||||
AuthorizeRenewal(cert *x509.Certificate) error
|
||||
AuthorizeRevoke(token string) error
|
||||
}
|
||||
|
@ -169,3 +171,29 @@ func (l *List) UnmarshalJSON(data []byte) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
|
||||
|
||||
// SanitizeSSHUserPrincipal grabs an email or a string with the format
|
||||
// local@domain and returns a sanitized version of the local, valid to be used
|
||||
// as a user name. If the email starts with a letter between a and z, the
|
||||
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
|
||||
func SanitizeSSHUserPrincipal(email string) string {
|
||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||
email = email[:i]
|
||||
}
|
||||
return strings.Map(func(r rune) rune {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
return r
|
||||
case r >= '0' && r <= '9':
|
||||
return r
|
||||
case r == '-':
|
||||
return '-'
|
||||
case r == '.': // drop dots
|
||||
return -1
|
||||
default:
|
||||
return '_'
|
||||
}
|
||||
}, strings.ToLower(email))
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package provisioner
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestType_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
@ -24,3 +26,29 @@ func TestType_String(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSSHUserPrincipal(t *testing.T) {
|
||||
type args struct {
|
||||
email string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{"simple", args{"foobar"}, "foobar"},
|
||||
{"camelcase", args{"FooBar"}, "foobar"},
|
||||
{"email", args{"foo@example.com"}, "foo"},
|
||||
{"email with dots", args{"foo.bar.zar@example.com"}, "foobarzar"},
|
||||
{"email with dashes", args{"foo-bar-zar@example.com"}, "foo-bar-zar"},
|
||||
{"email with underscores", args{"foo_bar_zar@example.com"}, "foo_bar_zar"},
|
||||
{"email with symbols", args{"Foo.Bar0123456789!#$%&'*+-/=?^_`{|}~;@example.com"}, "foobar0123456789________-___________"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := SanitizeSSHUserPrincipal(tt.args.email); got != tt.want {
|
||||
t.Errorf("SanitizeSSHUserPrincipal() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
306
authority/provisioner/sign_ssh_options.go
Normal file
306
authority/provisioner/sign_ssh_options.go
Normal file
|
@ -0,0 +1,306 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
// SSHUserCert is the string used to represent ssh.UserCert.
|
||||
SSHUserCert = "user"
|
||||
|
||||
// SSHHostCert is the string used to represent ssh.HostCert.
|
||||
SSHHostCert = "host"
|
||||
)
|
||||
|
||||
// SSHCertificateModifier is the interface used to change properties in an SSH
|
||||
// certificate.
|
||||
type SSHCertificateModifier interface {
|
||||
SignOption
|
||||
Modify(cert *ssh.Certificate) error
|
||||
}
|
||||
|
||||
// SSHCertificateOptionModifier is the interface used to add custom options used
|
||||
// to modify the SSH certificate.
|
||||
type SSHCertificateOptionModifier interface {
|
||||
SignOption
|
||||
Option(o SSHOptions) SSHCertificateModifier
|
||||
}
|
||||
|
||||
// SSHCertificateValidator is the interface used to validate an SSH certificate.
|
||||
type SSHCertificateValidator interface {
|
||||
SignOption
|
||||
Valid(cert *ssh.Certificate) error
|
||||
}
|
||||
|
||||
// SSHCertificateOptionsValidator is the interface used to validate the custom
|
||||
// options used to modify the SSH certificate.
|
||||
type SSHCertificateOptionsValidator interface {
|
||||
SignOption
|
||||
Valid(got SSHOptions) error
|
||||
}
|
||||
|
||||
// SSHOptions contains the options that can be passed to the SignSSH method.
|
||||
type SSHOptions struct {
|
||||
CertType string `json:"certType"`
|
||||
Principals []string `json:"principals"`
|
||||
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||
}
|
||||
|
||||
// Type returns the uint32 representation of the CertType.
|
||||
func (o SSHOptions) Type() uint32 {
|
||||
return sshCertTypeUInt32(o.CertType)
|
||||
}
|
||||
|
||||
// Modify implements SSHCertificateModifier and sets the SSHOption in the ssh.Certificate.
|
||||
func (o SSHOptions) Modify(cert *ssh.Certificate) error {
|
||||
switch o.CertType {
|
||||
case "": // ignore
|
||||
case SSHUserCert:
|
||||
cert.CertType = ssh.UserCert
|
||||
case SSHHostCert:
|
||||
cert.CertType = ssh.HostCert
|
||||
default:
|
||||
return errors.Errorf("ssh certificate has an unknown type: %s", o.CertType)
|
||||
}
|
||||
cert.ValidPrincipals = o.Principals
|
||||
if !o.ValidAfter.IsZero() {
|
||||
cert.ValidAfter = uint64(o.ValidAfter.Time().Unix())
|
||||
}
|
||||
if !o.ValidBefore.IsZero() {
|
||||
cert.ValidBefore = uint64(o.ValidBefore.Time().Unix())
|
||||
}
|
||||
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
|
||||
return errors.New("ssh certificate valid after cannot be greater than valid before")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// match compares two SSHOptions and return an error if they don't match. It
|
||||
// ignores zero values.
|
||||
func (o SSHOptions) match(got SSHOptions) error {
|
||||
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
|
||||
return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
|
||||
}
|
||||
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
|
||||
return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
|
||||
}
|
||||
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
|
||||
return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
|
||||
}
|
||||
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
|
||||
return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given
|
||||
// Key ID in the SSH certificate.
|
||||
type sshCertificateKeyIDModifier string
|
||||
|
||||
func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
||||
cert.KeyId = string(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the
|
||||
// certificate type to the SSH certificate.
|
||||
type sshCertificateCertTypeModifier string
|
||||
|
||||
func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error {
|
||||
cert.CertType = sshCertTypeUInt32(string(m))
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the
|
||||
// principals to the SSH certificate.
|
||||
type sshCertificatePrincipalsModifier []string
|
||||
|
||||
func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error {
|
||||
cert.ValidPrincipals = []string(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the
|
||||
// ValidAfter in the SSH certificate.
|
||||
type sshCertificateValidAfterModifier uint64
|
||||
|
||||
func (m sshCertificateValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
||||
cert.ValidAfter = uint64(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateValidBeforeModifier is an SSHCertificateModifier that sets the
|
||||
// ValidBefore in the SSH certificate.
|
||||
type sshCertificateValidBeforeModifier uint64
|
||||
|
||||
func (m sshCertificateValidBeforeModifier) Modify(cert *ssh.Certificate) error {
|
||||
cert.ValidBefore = uint64(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshCertificateDefaultModifier implements a SSHCertificateModifier that
|
||||
// modifies the certificate with the given options if they are not set.
|
||||
type sshCertificateDefaultsModifier SSHOptions
|
||||
|
||||
// Modify implements the SSHCertificateModifier interface.
|
||||
func (m sshCertificateDefaultsModifier) Modify(cert *ssh.Certificate) error {
|
||||
if cert.CertType == 0 {
|
||||
cert.CertType = sshCertTypeUInt32(m.CertType)
|
||||
}
|
||||
if len(cert.ValidPrincipals) == 0 {
|
||||
cert.ValidPrincipals = m.Principals
|
||||
}
|
||||
if cert.ValidAfter == 0 && !m.ValidAfter.IsZero() {
|
||||
cert.ValidAfter = uint64(m.ValidAfter.Unix())
|
||||
}
|
||||
if cert.ValidBefore == 0 && !m.ValidBefore.IsZero() {
|
||||
cert.ValidBefore = uint64(m.ValidBefore.Unix())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sshDefaultExtensionModifier implements an SSHCertificateModifier that sets
|
||||
// the default extensions in an SSH certificate.
|
||||
type sshDefaultExtensionModifier struct{}
|
||||
|
||||
func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate) error {
|
||||
switch cert.CertType {
|
||||
// Default to no extensions for HostCert.
|
||||
case ssh.HostCert:
|
||||
return nil
|
||||
case ssh.UserCert:
|
||||
if cert.Extensions == nil {
|
||||
cert.Extensions = make(map[string]string)
|
||||
}
|
||||
cert.Extensions["permit-X11-forwarding"] = ""
|
||||
cert.Extensions["permit-agent-forwarding"] = ""
|
||||
cert.Extensions["permit-port-forwarding"] = ""
|
||||
cert.Extensions["permit-pty"] = ""
|
||||
cert.Extensions["permit-user-rc"] = ""
|
||||
return nil
|
||||
default:
|
||||
return errors.New("ssh certificate type has not been set or is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
// sshCertificateValidityModifier is a SSHCertificateModifier checks the
|
||||
// validity bounds, setting them if they are not provided. It will fail if a
|
||||
// CertType has not been set or is not valid.
|
||||
type sshCertificateValidityModifier struct {
|
||||
*Claimer
|
||||
}
|
||||
|
||||
func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error {
|
||||
var d, min, max time.Duration
|
||||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
d = m.DefaultUserSSHCertDuration()
|
||||
min = m.MinUserSSHCertDuration()
|
||||
max = m.MaxUserSSHCertDuration()
|
||||
case ssh.HostCert:
|
||||
d = m.DefaultHostSSHCertDuration()
|
||||
min = m.MinHostSSHCertDuration()
|
||||
max = m.MaxHostSSHCertDuration()
|
||||
case 0:
|
||||
return errors.New("ssh certificate type has not been set")
|
||||
default:
|
||||
return errors.Errorf("unknown ssh certificate type %d", cert.CertType)
|
||||
}
|
||||
|
||||
if cert.ValidAfter == 0 {
|
||||
cert.ValidAfter = uint64(now().Unix())
|
||||
}
|
||||
if cert.ValidBefore == 0 {
|
||||
t := time.Unix(int64(cert.ValidAfter), 0)
|
||||
cert.ValidBefore = uint64(t.Add(d).Unix())
|
||||
}
|
||||
|
||||
diff := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
|
||||
switch {
|
||||
case diff < min:
|
||||
return errors.Errorf("ssh certificate duration cannot be lower than %s", min)
|
||||
case diff > max:
|
||||
return errors.Errorf("ssh certificate duration cannot be greater than %s", max)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// sshCertificateOptionsValidator validates the user SSHOptions with the ones
|
||||
// usually present in the token.
|
||||
type sshCertificateOptionsValidator SSHOptions
|
||||
|
||||
// Valid implements SSHCertificateOptionsValidator and returns nil if both
|
||||
// SSHOptions match.
|
||||
func (v sshCertificateOptionsValidator) Valid(got SSHOptions) error {
|
||||
want := SSHOptions(v)
|
||||
return want.match(got)
|
||||
}
|
||||
|
||||
// sshCertificateDefaultValidator implements a simple validator for all the
|
||||
// fields in the SSH certificate.
|
||||
type sshCertificateDefaultValidator struct{}
|
||||
|
||||
// Valid returns an error if the given certificate does not contain the necessary fields.
|
||||
func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
|
||||
switch {
|
||||
case len(cert.Nonce) == 0:
|
||||
return errors.New("ssh certificate nonce cannot be empty")
|
||||
case cert.Key == nil:
|
||||
return errors.New("ssh certificate key cannot be nil")
|
||||
case cert.Serial == 0:
|
||||
return errors.New("ssh certificate serial cannot be 0")
|
||||
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
|
||||
return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType)
|
||||
case cert.KeyId == "":
|
||||
return errors.New("ssh certificate key id cannot be empty")
|
||||
case len(cert.ValidPrincipals) == 0:
|
||||
return errors.New("ssh certificate valid principals cannot be empty")
|
||||
case cert.ValidAfter == 0:
|
||||
return errors.New("ssh certificate valid after cannot be 0")
|
||||
case cert.ValidBefore == 0:
|
||||
return errors.New("ssh certificate valid before cannot be 0")
|
||||
case cert.CertType == ssh.UserCert && len(cert.Extensions) == 0:
|
||||
return errors.New("ssh certificate extensions cannot be empty")
|
||||
case cert.SignatureKey == nil:
|
||||
return errors.New("ssh certificate signature key cannot be nil")
|
||||
case cert.Signature == nil:
|
||||
return errors.New("ssh certificate signature cannot be nil")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// sshCertTypeUInt32
|
||||
func sshCertTypeUInt32(ct string) uint32 {
|
||||
switch ct {
|
||||
case SSHUserCert:
|
||||
return ssh.UserCert
|
||||
case SSHHostCert:
|
||||
return ssh.HostCert
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// containsAllMembers reports whether all members of subgroup are within group.
|
||||
func containsAllMembers(group, subgroup []string) bool {
|
||||
lg, lsg := len(group), len(subgroup)
|
||||
if lsg > lg || (lg > 0 && lsg == 0) {
|
||||
return false
|
||||
}
|
||||
visit := make(map[string]struct{}, lg)
|
||||
for i := 0; i < lg; i++ {
|
||||
visit[group[i]] = struct{}{}
|
||||
}
|
||||
for i := 0; i < lsg; i++ {
|
||||
if _, ok := visit[subgroup[i]]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
125
authority/provisioner/ssh_test.go
Normal file
125
authority/provisioner/ssh_test.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func validateSSHCertificate(cert *ssh.Certificate, opts *SSHOptions) error {
|
||||
switch {
|
||||
case cert == nil:
|
||||
return fmt.Errorf("certificate is nil")
|
||||
case cert.Signature == nil:
|
||||
return fmt.Errorf("certificate signature is nil")
|
||||
case cert.SignatureKey == nil:
|
||||
return fmt.Errorf("certificate signature is nil")
|
||||
case !reflect.DeepEqual(cert.ValidPrincipals, opts.Principals):
|
||||
return fmt.Errorf("certificate principals are not equal, want %v, got %v", opts.Principals, cert.ValidPrincipals)
|
||||
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
|
||||
return fmt.Errorf("certificate type %v is not valid", cert.CertType)
|
||||
case opts.CertType == "user" && cert.CertType != ssh.UserCert:
|
||||
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.UserCert, cert.CertType)
|
||||
case opts.CertType == "host" && cert.CertType != ssh.HostCert:
|
||||
return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.HostCert, cert.CertType)
|
||||
case cert.ValidAfter != uint64(opts.ValidAfter.Unix()):
|
||||
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
|
||||
case cert.ValidBefore != uint64(opts.ValidBefore.Unix()):
|
||||
return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(int64(cert.ValidAfter), 0))
|
||||
case opts.CertType == "user" && len(cert.Extensions) != 5:
|
||||
return fmt.Errorf("certificate extensions number is invalid, want 5, got %d", len(cert.Extensions))
|
||||
case opts.CertType == "host" && len(cert.Extensions) != 0:
|
||||
return fmt.Errorf("certificate extensions number is invalid, want 0, got %d", len(cert.Extensions))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOption, signKey crypto.Signer) (*ssh.Certificate, error) {
|
||||
pub, err := ssh.NewPublicKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mods []SSHCertificateModifier
|
||||
var validators []SSHCertificateValidator
|
||||
|
||||
for _, op := range signOpts {
|
||||
switch o := op.(type) {
|
||||
// modify the ssh.Certificate
|
||||
case SSHCertificateModifier:
|
||||
mods = append(mods, o)
|
||||
// modify the ssh.Certificate given the SSHOptions
|
||||
case SSHCertificateOptionModifier:
|
||||
mods = append(mods, o.Option(opts))
|
||||
// validate the ssh.Certificate
|
||||
case SSHCertificateValidator:
|
||||
validators = append(validators, o)
|
||||
// validate the given SSHOptions
|
||||
case SSHCertificateOptionsValidator:
|
||||
if err := o.Valid(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("signSSH: invalid extra option type %T", o)
|
||||
}
|
||||
}
|
||||
|
||||
// Build base certificate with the key and some random values
|
||||
cert := &ssh.Certificate{
|
||||
Nonce: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0},
|
||||
Key: pub,
|
||||
Serial: 1234567890,
|
||||
}
|
||||
|
||||
// Use opts to modify the certificate
|
||||
if err := opts.Modify(cert); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use provisioner modifiers
|
||||
for _, m := range mods {
|
||||
if err := m.Modify(cert); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get signer from authority keys
|
||||
var signer ssh.Signer
|
||||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
signer, err = ssh.NewSignerFromSigner(signKey)
|
||||
case ssh.HostCert:
|
||||
signer, err = ssh.NewSignerFromSigner(signKey)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected ssh certificate type: %d", cert.CertType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
// Get bytes for signing trailing the signature length.
|
||||
data := cert.Marshal()
|
||||
data = data[:len(data)-4]
|
||||
|
||||
// Sign the certificate
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert.Signature = sig
|
||||
|
||||
// User provisioners validators
|
||||
for _, v := range validators {
|
||||
if err := v.Valid(cert); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
|
@ -57,6 +57,17 @@ func (t *TimeDuration) SetTime(tt time.Time) {
|
|||
t.t, t.d = tt, 0
|
||||
}
|
||||
|
||||
// IsZero returns true the TimeDuration represents the zero value, false
|
||||
// otherwise.
|
||||
func (t *TimeDuration) IsZero() bool {
|
||||
return t.t.IsZero() && t.d == 0
|
||||
}
|
||||
|
||||
// Equal returns if t and other are equal.
|
||||
func (t *TimeDuration) Equal(other *TimeDuration) bool {
|
||||
return t.t.Equal(other.t) && t.d == other.d
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface. If the time is set it
|
||||
// will return the time in RFC 3339 format if not it will return the duration
|
||||
// string.
|
||||
|
@ -64,7 +75,7 @@ func (t TimeDuration) MarshalJSON() ([]byte, error) {
|
|||
switch {
|
||||
case t.t.IsZero():
|
||||
if t.d == 0 {
|
||||
return []byte("null"), nil
|
||||
return []byte(`""`), nil
|
||||
}
|
||||
return json.Marshal(t.d.String())
|
||||
default:
|
||||
|
@ -102,11 +113,16 @@ func (t *TimeDuration) UnmarshalJSON(data []byte) error {
|
|||
return errors.Errorf("failed to parse %s", data)
|
||||
}
|
||||
|
||||
// Time calculates the embedded time.Time, sets it if necessary, and returns it.
|
||||
// Time calculates the time if needed and returns it.
|
||||
func (t *TimeDuration) Time() time.Time {
|
||||
return t.RelativeTime(now())
|
||||
}
|
||||
|
||||
// Unix calculates the time if needed it and returns the Unix time in seconds.
|
||||
func (t *TimeDuration) Unix() int64 {
|
||||
return t.RelativeTime(now()).Unix()
|
||||
}
|
||||
|
||||
// RelativeTime returns the embedded time.Time or the base time plus the
|
||||
// duration if this is not zero.
|
||||
func (t *TimeDuration) RelativeTime(base time.Time) time.Time {
|
||||
|
|
|
@ -6,6 +6,17 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func mockNow() (time.Time, func()) {
|
||||
tm := time.Unix(1584198566, 535897000).UTC()
|
||||
nowFn := now
|
||||
now = func() time.Time {
|
||||
return tm
|
||||
}
|
||||
return tm, func() {
|
||||
now = nowFn
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTimeDuration(t *testing.T) {
|
||||
tm := time.Unix(1584198566, 535897000).UTC()
|
||||
type args struct {
|
||||
|
@ -137,7 +148,7 @@ func TestTimeDuration_MarshalJSON(t *testing.T) {
|
|||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"null", TimeDuration{}, []byte("null"), false},
|
||||
{"empty", TimeDuration{}, []byte(`""`), false},
|
||||
{"timestamp", TimeDuration{t: tm}, []byte(`"2020-03-14T15:09:26.535897Z"`), false},
|
||||
{"duration", TimeDuration{d: 1 * time.Hour}, []byte(`"1h0m0s"`), false},
|
||||
{"fail", TimeDuration{t: time.Date(-1, 0, 0, 0, 0, 0, 0, time.UTC)}, nil, true},
|
||||
|
@ -166,7 +177,7 @@ func TestTimeDuration_UnmarshalJSON(t *testing.T) {
|
|||
want *TimeDuration
|
||||
wantErr bool
|
||||
}{
|
||||
{"null", args{[]byte("null")}, &TimeDuration{}, false},
|
||||
{"empty", args{[]byte(`""`)}, &TimeDuration{}, false},
|
||||
{"timestamp", args{[]byte(`"2020-03-14T15:09:26.535897Z"`)}, &TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false},
|
||||
{"duration", args{[]byte(`"1h"`)}, &TimeDuration{d: time.Hour}, false},
|
||||
{"fail", args{[]byte("123")}, &TimeDuration{}, true},
|
||||
|
@ -186,15 +197,8 @@ func TestTimeDuration_UnmarshalJSON(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTimeDuration_Time(t *testing.T) {
|
||||
nowFn := now
|
||||
defer func() {
|
||||
now = nowFn
|
||||
now()
|
||||
}()
|
||||
tm := time.Unix(1584198566, 535897000).UTC()
|
||||
now = func() time.Time {
|
||||
return tm
|
||||
}
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
tests := []struct {
|
||||
name string
|
||||
timeDuration *TimeDuration
|
||||
|
@ -211,6 +215,30 @@ func TestTimeDuration_Time(t *testing.T) {
|
|||
got := tt.timeDuration.Time()
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("TimeDuration.Time() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeDuration_Unix(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
tests := []struct {
|
||||
name string
|
||||
timeDuration *TimeDuration
|
||||
want int64
|
||||
}{
|
||||
{"zero", nil, -62135596800},
|
||||
{"zero", &TimeDuration{}, -62135596800},
|
||||
{"timestamp", &TimeDuration{t: tm}, 1584198566},
|
||||
{"local", &TimeDuration{t: tm.Local()}, 1584198566},
|
||||
{"duration", &TimeDuration{d: 1 * time.Hour}, 1584202166},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.timeDuration.Unix()
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("TimeDuration.Unix() = %v, want %v", got, tt.want)
|
||||
|
||||
}
|
||||
})
|
||||
|
@ -218,15 +246,8 @@ func TestTimeDuration_Time(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTimeDuration_String(t *testing.T) {
|
||||
nowFn := now
|
||||
defer func() {
|
||||
now = nowFn
|
||||
now()
|
||||
}()
|
||||
tm := time.Unix(1584198566, 535897000).UTC()
|
||||
now = func() time.Time {
|
||||
return tm
|
||||
}
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
type fields struct {
|
||||
t time.Time
|
||||
d time.Duration
|
||||
|
|
|
@ -480,6 +480,54 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
|
|||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func generateSimpleSSHUserToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
|
||||
return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SSHOptions{
|
||||
CertType: "user",
|
||||
Principals: []string{"name"},
|
||||
}, jwk)
|
||||
}
|
||||
|
||||
func generateSimpleSSHHostToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
|
||||
return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SSHOptions{
|
||||
CertType: "host",
|
||||
Principals: []string{"smallstep.com"},
|
||||
}, jwk)
|
||||
}
|
||||
|
||||
func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *SSHOptions, jwk *jose.JSONWebKey) (string, error) {
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
id, err := randutil.ASCII(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims := struct {
|
||||
jose.Claims
|
||||
Step *stepPayload `json:"step,omitempty"`
|
||||
}{
|
||||
Claims: jose.Claims{
|
||||
ID: id,
|
||||
Subject: sub,
|
||||
Issuer: iss,
|
||||
IssuedAt: jose.NewNumericDate(iat),
|
||||
NotBefore: jose.NewNumericDate(iat),
|
||||
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||
Audience: []string{aud},
|
||||
},
|
||||
Step: &stepPayload{
|
||||
SSH: sshOpts,
|
||||
},
|
||||
}
|
||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
|
|
|
@ -13,7 +13,7 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
|||
key, ok := a.provisioners.LoadEncryptedKey(kid)
|
||||
if !ok {
|
||||
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
||||
http.StatusNotFound, context{}}
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
|
|||
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("provisioner not found"),
|
||||
http.StatusNotFound, context{}}
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
a: a,
|
||||
kid: "foo",
|
||||
err: &apiError{errors.Errorf("encrypted key with kid foo was not found"),
|
||||
http.StatusNotFound, context{}},
|
||||
http.StatusNotFound, apiCtx{}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
|
|
@ -12,13 +12,13 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
|||
val, ok := a.certificates.Load(sum)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("certificate with fingerprint %s was not found", sum),
|
||||
http.StatusNotFound, context{}}
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
}
|
||||
|
||||
crt, ok := val.(*x509.Certificate)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
}
|
||||
return crt, nil
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error)
|
|||
if !ok {
|
||||
federation = nil
|
||||
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
return false
|
||||
}
|
||||
federation = append(federation, crt)
|
||||
|
|
|
@ -19,8 +19,8 @@ func TestRoot(t *testing.T) {
|
|||
sum string
|
||||
err *apiError
|
||||
}{
|
||||
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
|
||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
|
||||
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, apiCtx{}}},
|
||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, apiCtx{}}},
|
||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
||||
}
|
||||
|
||||
|
|
239
authority/ssh.go
Normal file
239
authority/ssh.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
// SSHAddUserPrincipal is the principal that will run the add user command.
|
||||
// Defaults to "provisioner" but it can be changed in the configuration.
|
||||
SSHAddUserPrincipal = "provisioner"
|
||||
|
||||
// SSHAddUserCommand is the default command to run to add a new user.
|
||||
// Defaults to "sudo useradd -m <principal>; nc -q0 localhost 22" but it can be changed in the
|
||||
// configuration. The string "<principal>" will be replace by the new
|
||||
// principal to add.
|
||||
SSHAddUserCommand = "sudo useradd -m <principal>; nc -q0 localhost 22"
|
||||
)
|
||||
|
||||
// SignSSH creates a signed SSH certificate with the given public key and options.
|
||||
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
var mods []provisioner.SSHCertificateModifier
|
||||
var validators []provisioner.SSHCertificateValidator
|
||||
|
||||
for _, op := range signOpts {
|
||||
switch o := op.(type) {
|
||||
// modify the ssh.Certificate
|
||||
case provisioner.SSHCertificateModifier:
|
||||
mods = append(mods, o)
|
||||
// modify the ssh.Certificate given the SSHOptions
|
||||
case provisioner.SSHCertificateOptionModifier:
|
||||
mods = append(mods, o.Option(opts))
|
||||
// validate the ssh.Certificate
|
||||
case provisioner.SSHCertificateValidator:
|
||||
validators = append(validators, o)
|
||||
// validate the given SSHOptions
|
||||
case provisioner.SSHCertificateOptionsValidator:
|
||||
if err := o.Valid(opts); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
}
|
||||
default:
|
||||
return nil, &apiError{
|
||||
err: errors.Errorf("signSSH: invalid extra option type %T", o),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSH: error reading random number"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
// Build base certificate with the key and some random values
|
||||
cert := &ssh.Certificate{
|
||||
Nonce: []byte(nonce),
|
||||
Key: key,
|
||||
Serial: serial,
|
||||
}
|
||||
|
||||
// Use opts to modify the certificate
|
||||
if err := opts.Modify(cert); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
}
|
||||
|
||||
// Use provisioner modifiers
|
||||
for _, m := range mods {
|
||||
if err := m.Modify(cert); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
}
|
||||
}
|
||||
|
||||
// Get signer from authority keys
|
||||
var signer ssh.Signer
|
||||
switch cert.CertType {
|
||||
case ssh.UserCert:
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSH: user certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
if signer, err = ssh.NewSignerFromSigner(a.sshCAUserCertSignKey); err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSH: error creating signer"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
case ssh.HostCert:
|
||||
if a.sshCAHostCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSH: host certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
if signer, err = ssh.NewSignerFromSigner(a.sshCAHostCertSignKey); err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSH: error creating signer"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, &apiError{
|
||||
err: errors.Errorf("signSSH: unexpected ssh certificate type: %d", cert.CertType),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
// Get bytes for signing trailing the signature length.
|
||||
data := cert.Marshal()
|
||||
data = data[:len(data)-4]
|
||||
|
||||
// Sign the certificate
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSH: error signing certificate"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
cert.Signature = sig
|
||||
|
||||
// User provisioners validators
|
||||
for _, v := range validators {
|
||||
if err := v.Valid(cert); err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusForbidden}
|
||||
}
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// SignSSHAddUser signs a certificate that provisions a new user in a server.
|
||||
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if a.sshCAUserCertSignKey == nil {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSHAddUser: user certificate signing is not enabled"),
|
||||
code: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
if subject.CertType != ssh.UserCert {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSHProxy: certificate is not a user certificate"),
|
||||
code: http.StatusForbidden,
|
||||
}
|
||||
}
|
||||
if len(subject.ValidPrincipals) != 1 {
|
||||
return nil, &apiError{
|
||||
err: errors.New("signSSHProxy: certificate does not have only one principal"),
|
||||
code: http.StatusForbidden,
|
||||
}
|
||||
}
|
||||
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err, code: http.StatusInternalServerError}
|
||||
}
|
||||
|
||||
var serial uint64
|
||||
if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSHProxy: error reading random number"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
signer, err := ssh.NewSignerFromSigner(a.sshCAUserCertSignKey)
|
||||
if err != nil {
|
||||
return nil, &apiError{
|
||||
err: errors.Wrap(err, "signSSHProxy: error creating signer"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
principal := subject.ValidPrincipals[0]
|
||||
addUserPrincipal := a.getAddUserPrincipal()
|
||||
|
||||
cert := &ssh.Certificate{
|
||||
Nonce: []byte(nonce),
|
||||
Key: key,
|
||||
Serial: serial,
|
||||
CertType: ssh.UserCert,
|
||||
KeyId: principal + "-" + addUserPrincipal,
|
||||
ValidPrincipals: []string{addUserPrincipal},
|
||||
ValidAfter: subject.ValidAfter,
|
||||
ValidBefore: subject.ValidBefore,
|
||||
Permissions: ssh.Permissions{
|
||||
CriticalOptions: map[string]string{
|
||||
"force-command": a.getAddUserCommand(principal),
|
||||
},
|
||||
},
|
||||
SignatureKey: signer.PublicKey(),
|
||||
}
|
||||
|
||||
// Get bytes for signing trailing the signature length.
|
||||
data := cert.Marshal()
|
||||
data = data[:len(data)-4]
|
||||
|
||||
// Sign the certificate
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert.Signature = sig
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
||||
if a.config.SSH.AddUserPrincipal == "" {
|
||||
return SSHAddUserPrincipal
|
||||
}
|
||||
return a.config.SSH.AddUserPrincipal
|
||||
}
|
||||
|
||||
func (a *Authority) getAddUserCommand(principal string) string {
|
||||
var cmd string
|
||||
if a.config.SSH.AddUserCommand == "" {
|
||||
cmd = SSHAddUserCommand
|
||||
} else {
|
||||
cmd = a.config.SSH.AddUserCommand
|
||||
}
|
||||
return strings.Replace(cmd, "<principal>", principal, -1)
|
||||
}
|
252
authority/ssh_test.go
Normal file
252
authority/ssh_test.go
Normal file
|
@ -0,0 +1,252 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type sshTestModifier ssh.Certificate
|
||||
|
||||
func (m sshTestModifier) Modify(cert *ssh.Certificate) error {
|
||||
if m.CertType != 0 {
|
||||
cert.CertType = m.CertType
|
||||
}
|
||||
if m.KeyId != "" {
|
||||
cert.KeyId = m.KeyId
|
||||
}
|
||||
if m.ValidAfter != 0 {
|
||||
cert.ValidAfter = m.ValidAfter
|
||||
}
|
||||
if m.ValidBefore != 0 {
|
||||
cert.ValidBefore = m.ValidBefore
|
||||
}
|
||||
if len(m.ValidPrincipals) != 0 {
|
||||
cert.ValidPrincipals = m.ValidPrincipals
|
||||
}
|
||||
if m.Permissions.CriticalOptions != nil {
|
||||
cert.Permissions.CriticalOptions = m.Permissions.CriticalOptions
|
||||
}
|
||||
if m.Permissions.Extensions != nil {
|
||||
cert.Permissions.Extensions = m.Permissions.Extensions
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type sshTestCertModifier string
|
||||
|
||||
func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error {
|
||||
if m == "" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(string(m))
|
||||
}
|
||||
|
||||
type sshTestCertValidator string
|
||||
|
||||
func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(string(v))
|
||||
}
|
||||
|
||||
type sshTestOptionsValidator string
|
||||
|
||||
func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(string(v))
|
||||
}
|
||||
|
||||
type sshTestOptionsModifier string
|
||||
|
||||
func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertificateModifier {
|
||||
return sshTestCertModifier(string(m))
|
||||
}
|
||||
|
||||
func TestAuthority_SignSSH(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.FatalError(t, err)
|
||||
pub, err := ssh.NewPublicKey(key.Public())
|
||||
assert.FatalError(t, err)
|
||||
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userOptions := sshTestModifier{
|
||||
CertType: ssh.UserCert,
|
||||
}
|
||||
hostOptions := sshTestModifier{
|
||||
CertType: ssh.HostCert,
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
type fields struct {
|
||||
sshCAUserCertSignKey crypto.Signer
|
||||
sshCAHostCertSignKey crypto.Signer
|
||||
}
|
||||
type args struct {
|
||||
key ssh.PublicKey
|
||||
opts provisioner.SSHOptions
|
||||
signOpts []provisioner.SignOption
|
||||
}
|
||||
type want struct {
|
||||
CertType uint32
|
||||
Principals []string
|
||||
ValidAfter uint64
|
||||
ValidBefore uint64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want want
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions}}, want{CertType: ssh.UserCert}, false},
|
||||
{"ok-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{hostOptions}}, want{CertType: ssh.HostCert}, false},
|
||||
{"ok-opts-type-user", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert}, false},
|
||||
{"ok-opts-type-host", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert}, false},
|
||||
{"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
|
||||
{"ok-opts-principals", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false},
|
||||
{"ok-opts-valid-after", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false},
|
||||
{"ok-opts-valid-before", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false},
|
||||
{"ok-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false},
|
||||
{"ok-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false},
|
||||
{"ok-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false},
|
||||
{"ok-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false},
|
||||
{"fail-opts-type", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{CertType: "foo"}, []provisioner.SignOption{}}, want{}, true},
|
||||
{"fail-cert-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("an error")}}, want{}, true},
|
||||
{"fail-cert-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("an error")}}, want{}, true},
|
||||
{"fail-opts-validator", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("an error")}}, want{}, true},
|
||||
{"fail-opts-modifier", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("an error")}}, want{}, true},
|
||||
{"fail-bad-sign-options", fields{signKey, signKey}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, "wrong type"}}, want{}, true},
|
||||
{"fail-no-user-key", fields{nil, signKey}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{}, true},
|
||||
{"fail-no-host-key", fields{signKey, nil}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{}, true},
|
||||
{"fail-bad-type", fields{signKey, nil}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{sshTestModifier{CertType: 0}}}, want{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := testAuthority(t)
|
||||
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
||||
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
||||
|
||||
got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil && assert.NotNil(t, got) {
|
||||
assert.Equals(t, tt.want.CertType, got.CertType)
|
||||
assert.Equals(t, tt.want.Principals, got.ValidPrincipals)
|
||||
assert.Equals(t, tt.want.ValidAfter, got.ValidAfter)
|
||||
assert.Equals(t, tt.want.ValidBefore, got.ValidBefore)
|
||||
assert.NotNil(t, got.Key)
|
||||
assert.NotNil(t, got.Nonce)
|
||||
assert.NotEquals(t, 0, got.Serial)
|
||||
assert.NotNil(t, got.Signature)
|
||||
assert.NotNil(t, got.SignatureKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_SignSSHAddUser(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.FatalError(t, err)
|
||||
pub, err := ssh.NewPublicKey(key.Public())
|
||||
assert.FatalError(t, err)
|
||||
signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
sshCAUserCertSignKey crypto.Signer
|
||||
sshCAHostCertSignKey crypto.Signer
|
||||
addUserPrincipal string
|
||||
addUserCommand string
|
||||
}
|
||||
type args struct {
|
||||
key ssh.PublicKey
|
||||
subject *ssh.Certificate
|
||||
}
|
||||
type want struct {
|
||||
CertType uint32
|
||||
Principals []string
|
||||
ValidAfter uint64
|
||||
ValidBefore uint64
|
||||
ForceCommand string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
validCert := &ssh.Certificate{
|
||||
CertType: ssh.UserCert,
|
||||
ValidPrincipals: []string{"user"},
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}
|
||||
validWant := want{
|
||||
CertType: ssh.UserCert,
|
||||
Principals: []string{"provisioner"},
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
ForceCommand: "sudo useradd -m user; nc -q0 localhost 22",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want want
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{signKey, signKey, "", ""}, args{pub, validCert}, validWant, false},
|
||||
{"ok-no-host-key", fields{signKey, nil, "", ""}, args{pub, validCert}, validWant, false},
|
||||
{"ok-custom-principal", fields{signKey, signKey, "my-principal", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "sudo useradd -m user; nc -q0 localhost 22"}, false},
|
||||
{"ok-custom-command", fields{signKey, signKey, "", "foo <principal> <principal>"}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"provisioner"}, ForceCommand: "foo user user"}, false},
|
||||
{"ok-custom-principal-and-command", fields{signKey, signKey, "my-principal", "foo <principal> <principal>"}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "foo user user"}, false},
|
||||
{"fail-no-user-key", fields{nil, signKey, "", ""}, args{pub, validCert}, want{}, true},
|
||||
{"fail-no-user-cert", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"foo"}}}, want{}, true},
|
||||
{"fail-no-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, want{}, true},
|
||||
{"fail-many-principals", fields{signKey, signKey, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"foo", "bar"}}}, want{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := testAuthority(t)
|
||||
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
||||
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
||||
a.config.SSH = &SSHConfig{
|
||||
AddUserPrincipal: tt.fields.addUserPrincipal,
|
||||
AddUserCommand: tt.fields.addUserCommand,
|
||||
}
|
||||
got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil && assert.NotNil(t, got) {
|
||||
assert.Equals(t, tt.want.CertType, got.CertType)
|
||||
assert.Equals(t, tt.want.Principals, got.ValidPrincipals)
|
||||
assert.Equals(t, tt.args.subject.ValidPrincipals[0]+"-"+tt.want.Principals[0], got.KeyId)
|
||||
assert.Equals(t, tt.want.ValidAfter, got.ValidAfter)
|
||||
assert.Equals(t, tt.want.ValidBefore, got.ValidBefore)
|
||||
assert.Equals(t, map[string]string{"force-command": tt.want.ForceCommand}, got.CriticalOptions)
|
||||
assert.Equals(t, nil, got.Extensions)
|
||||
assert.NotNil(t, got.Key)
|
||||
assert.NotNil(t, got.Nonce)
|
||||
assert.NotEquals(t, 0, got.Serial)
|
||||
assert.NotNil(t, got.Signature)
|
||||
assert.NotNil(t, got.SignatureKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -58,7 +58,7 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
|||
// Sign creates a signed certificate from a certificate signing request.
|
||||
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
|
||||
var (
|
||||
errContext = context{"csr": csr, "signOptions": signOpts}
|
||||
errContext = apiCtx{"csr": csr, "signOptions": signOpts}
|
||||
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
|
||||
certValidators = []provisioner.CertificateValidator{}
|
||||
issIdentity = a.intermediateIdentity
|
||||
|
@ -181,23 +181,23 @@ func (a *Authority) Renew(oldCert *x509.Certificate) (*x509.Certificate, *x509.C
|
|||
leaf, err := x509util.NewLeafProfileWithTemplate(newCert,
|
||||
issIdentity.Crt, issIdentity.Key)
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{err, http.StatusInternalServerError, context{}}
|
||||
return nil, nil, &apiError{err, http.StatusInternalServerError, apiCtx{}}
|
||||
}
|
||||
crtBytes, err := leaf.CreateCertificate()
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{errors.Wrap(err, "error renewing certificate from existing server certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
}
|
||||
|
||||
serverCert, err := x509.ParseCertificate(crtBytes)
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{errors.Wrap(err, "error parsing new server certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
}
|
||||
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{errors.Wrap(err, "error parsing intermediate certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
http.StatusInternalServerError, apiCtx{}}
|
||||
}
|
||||
|
||||
return serverCert, caCert, nil
|
||||
|
@ -222,7 +222,7 @@ type RevokeOptions struct {
|
|||
//
|
||||
// TODO: Add OCSP and CRL support.
|
||||
func (a *Authority) Revoke(opts *RevokeOptions) error {
|
||||
errContext := context{
|
||||
errContext := apiCtx{
|
||||
"serialNumber": opts.Serial,
|
||||
"reasonCode": opts.ReasonCode,
|
||||
"reason": opts.Reason,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
|
@ -103,7 +104,8 @@ func TestSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
extraOpts, err := a.Authorize(token)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
extraOpts, err := a.Authorize(ctx, token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type signTest struct {
|
||||
|
@ -124,7 +126,7 @@ func TestSign(t *testing.T) {
|
|||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: invalid certificate request"),
|
||||
http.StatusBadRequest,
|
||||
context{"csr": csr, "signOptions": signOpts},
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -138,7 +140,7 @@ func TestSign(t *testing.T) {
|
|||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: invalid extra option type string"),
|
||||
http.StatusInternalServerError,
|
||||
context{"csr": csr, "signOptions": signOpts},
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -153,7 +155,7 @@ func TestSign(t *testing.T) {
|
|||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: default ASN1DN template cannot be nil"),
|
||||
http.StatusInternalServerError,
|
||||
context{"csr": csr, "signOptions": signOpts},
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -168,7 +170,7 @@ func TestSign(t *testing.T) {
|
|||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: error creating new leaf certificate"),
|
||||
http.StatusInternalServerError,
|
||||
context{"csr": csr, "signOptions": signOpts},
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -185,7 +187,7 @@ func TestSign(t *testing.T) {
|
|||
signOpts: _signOpts,
|
||||
err: &apiError{errors.New("sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"),
|
||||
http.StatusUnauthorized,
|
||||
context{"csr": csr, "signOptions": _signOpts},
|
||||
apiCtx{"csr": csr, "signOptions": _signOpts},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -200,7 +202,7 @@ func TestSign(t *testing.T) {
|
|||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
||||
http.StatusUnauthorized,
|
||||
context{"csr": csr, "signOptions": signOpts},
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -227,7 +229,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
|
|||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: rsa key in CSR must be at least 2048 bits (256 bytes)"),
|
||||
http.StatusUnauthorized,
|
||||
context{"csr": csr, "signOptions": signOpts},
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -238,7 +240,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
|
|||
storeCertificate: func(crt *x509.Certificate) error {
|
||||
return &apiError{errors.New("force"),
|
||||
http.StatusInternalServerError,
|
||||
context{"csr": csr, "signOptions": signOpts}}
|
||||
apiCtx{"csr": csr, "signOptions": signOpts}}
|
||||
},
|
||||
}
|
||||
return &signTest{
|
||||
|
@ -248,7 +250,7 @@ ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG
|
|||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: error storing certificate in db: force"),
|
||||
http.StatusInternalServerError,
|
||||
context{"csr": csr, "signOptions": signOpts},
|
||||
apiCtx{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -401,7 +403,7 @@ func TestRenew(t *testing.T) {
|
|||
auth: _a,
|
||||
crt: crt,
|
||||
err: &apiError{errors.New("error renewing certificate from existing server certificate"),
|
||||
http.StatusInternalServerError, context{}},
|
||||
http.StatusInternalServerError, apiCtx{}},
|
||||
}, nil
|
||||
},
|
||||
"fail-unauthorized": func() (*renewTest, error) {
|
||||
|
@ -596,7 +598,7 @@ func TestRevoke(t *testing.T) {
|
|||
validAudience := []string{"https://test.ca.smallstep.com/revoke"}
|
||||
now := time.Now().UTC()
|
||||
getCtx := func() map[string]interface{} {
|
||||
return context{
|
||||
return apiCtx{
|
||||
"serialNumber": "sn",
|
||||
"reasonCode": reasonCode,
|
||||
"reason": reason,
|
||||
|
|
22
ca/client.go
22
ca/client.go
|
@ -373,6 +373,28 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
|
|||
return &sign, nil
|
||||
}
|
||||
|
||||
// SignSSH performs the SSH certificate sign request to the CA and returns the
|
||||
// api.SignSSHResponse struct.
|
||||
func (c *Client) SignSSH(req *api.SignSSHRequest) (*api.SignSSHResponse, error) {
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshaling request")
|
||||
}
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign-ssh"})
|
||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readError(resp.Body)
|
||||
}
|
||||
var sign api.SignSSHResponse
|
||||
if err := readJSON(resp.Body, &sign); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
}
|
||||
return &sign, nil
|
||||
}
|
||||
|
||||
// Renew performs the renew request to the CA and returns the api.SignResponse
|
||||
// struct.
|
||||
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
||||
|
|
|
@ -280,7 +280,11 @@ func stringifyFlag(f cli.Flag) string {
|
|||
usage := fv.FieldByName("Usage").String()
|
||||
placeholder := placeholderString.FindString(usage)
|
||||
if placeholder == "" {
|
||||
placeholder = "<value>"
|
||||
switch f.(type) {
|
||||
case cli.BoolFlag, cli.BoolTFlag:
|
||||
default:
|
||||
placeholder = "<value>"
|
||||
}
|
||||
}
|
||||
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue