Merge pull request #22 from smallstep/mariano/multiroot

Multiple roots and federation
This commit is contained in:
Mariano Cano 2019-01-14 18:15:33 -08:00 committed by GitHub
commit d0e0217955
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 1782 additions and 136 deletions

View file

@ -22,9 +22,11 @@ type Authority interface {
GetTLSOptions() *tlsutil.TLSOptions
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
GetEncryptedKey(kid string) (string, error)
GetRoots() (federation []*x509.Certificate, err error)
GetFederation() ([]*x509.Certificate, error)
}
// Certificate wraps a *x509.Certificate and adds the json.Marshaler interface.
@ -186,6 +188,16 @@ type SignResponse struct {
TLS *tls.ConnectionState `json:"-"`
}
// RootsResponse is the response object of the roots request.
type RootsResponse struct {
Certificates []Certificate `json:"crts"`
}
// FederationResponse is the response object of the federation request.
type FederationResponse struct {
Certificates []Certificate `json:"crts"`
}
// caHandler is the type used to implement the different CA HTTP endpoints.
type caHandler struct {
Authority Authority
@ -205,6 +217,8 @@ func (h *caHandler) Route(r Router) {
r.MethodFunc("POST", "/renew", h.Renew)
r.MethodFunc("GET", "/provisioners", h.Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
r.MethodFunc("GET", "/roots", h.Roots)
r.MethodFunc("GET", "/federation", h.Federation)
// For compatibility with old code:
r.MethodFunc("POST", "/re-sign", h.Renew)
}
@ -320,6 +334,44 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
JSON(w, &ProvisionerKeyResponse{key})
}
// Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
WriteError(w, Forbidden(err))
return
}
certs := make([]Certificate, len(roots))
for i := range roots {
certs[i] = Certificate{roots[i]}
}
w.WriteHeader(http.StatusCreated)
JSON(w, &RootsResponse{
Certificates: certs,
})
}
// Federation returns all the public certificates in the federation.
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation()
if err != nil {
WriteError(w, Forbidden(err))
return
}
certs := make([]Certificate, len(federated))
for i := range federated {
certs[i] = Certificate{federated[i]}
}
w.WriteHeader(http.StatusCreated)
JSON(w, &FederationResponse{
Certificates: certs,
})
}
func parseCursor(r *http.Request) (cursor string, limit int, err error) {
q := r.URL.Query()
cursor = q.Get("cursor")

View file

@ -392,6 +392,8 @@ type mockAuthority struct {
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
getEncryptedKey func(kid string) (string, error)
getRoots func() ([]*x509.Certificate, error)
getFederation func() ([]*x509.Certificate, error)
}
func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
@ -443,6 +445,20 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
return m.ret1.(string), m.err
}
func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) {
if m.getRoots != nil {
return m.getRoots()
}
return m.ret1.([]*x509.Certificate), m.err
}
func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
if m.getFederation != nil {
return m.getFederation()
}
return m.ret1.([]*x509.Certificate), m.err
}
func Test_caHandler_Route(t *testing.T) {
type fields struct {
Authority Authority
@ -812,3 +828,95 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
})
}
}
func Test_caHandler_Roots(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
tests := []struct {
name string
tls *tls.ConnectionState
cert *x509.Certificate
root *x509.Certificate
err error
statusCode int
}{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
}
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Roots(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Roots unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected)
}
}
})
}
}
func Test_caHandler_Federation(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
tests := []struct {
name string
tls *tls.ConnectionState
cert *x509.Certificate
root *x509.Certificate
err error
statusCode int
}{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
}
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Federation(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Federation unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected)
}
}
})
}
}

View file

@ -2,7 +2,7 @@ package authority
import (
"crypto/sha256"
realx509 "crypto/x509"
"crypto/x509"
"encoding/hex"
"fmt"
"sync"
@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority"
// Authority implements the Certificate Authority internal interface.
type Authority struct {
config *Config
rootX509Crt *realx509.Certificate
rootX509Certs []*x509.Certificate
intermediateIdentity *x509util.Identity
validateOnce bool
certificates *sync.Map
@ -79,15 +79,29 @@ func (a *Authority) init() error {
}
var err error
// First load the root using our modified pem/x509 package.
a.rootX509Crt, err = pemutil.ReadCertificate(a.config.Root)
// Load the root certificates and add them to the certificate store
a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root))
for i, path := range a.config.Root {
crt, err := pemutil.ReadCertificate(path)
if err != nil {
return err
}
// Add root certificate to the certificate map
sum := sha256.Sum256(a.rootX509Crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
sum := sha256.Sum256(crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
a.rootX509Certs[i] = crt
}
// Add federated roots
for _, path := range a.config.FederatedRoots {
crt, err := pemutil.ReadCertificate(path)
if err != nil {
return err
}
sum := sha256.Sum256(crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
}
// Decrypt and load intermediate public / private key pair.
if len(a.config.Password) > 0 {

View file

@ -38,7 +38,7 @@ func testAuthority(t *testing.T) *Authority {
}
c := &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.ca.smallstep.com"},
@ -68,7 +68,7 @@ func TestAuthorityNew(t *testing.T) {
"fail bad root": func(t *testing.T) *newTest {
c, err := LoadConfiguration("../ca/testdata/ca.json")
assert.FatalError(t, err)
c.Root = "foo"
c.Root = []string{"foo"}
return &newTest{
config: c,
err: errors.New("open foo failed: no such file or directory"),
@ -105,10 +105,10 @@ func TestAuthorityNew(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
sum := sha256.Sum256(auth.rootX509Crt.Raw)
sum := sha256.Sum256(auth.rootX509Certs[0].Raw)
root, ok := auth.certificates.Load(hex.EncodeToString(sum[:]))
assert.Fatal(t, ok)
assert.Equals(t, auth.rootX509Crt, root)
assert.Equals(t, auth.rootX509Certs[0], root)
assert.True(t, auth.initOnce)
assert.NotNil(t, auth.intermediateIdentity)

View file

@ -33,38 +33,10 @@ var (
}
)
type duration struct {
time.Duration
}
// MarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) UnmarshalJSON(data []byte) (err error) {
var s string
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
if d.Duration, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
return
}
// Config represents the CA configuration and it's mapped to a JSON object.
type Config struct {
Root string `json:"root"`
Root multiString `json:"root"`
FederatedRoots []string `json:"federatedRoots"`
IntermediateCert string `json:"crt"`
IntermediateKey string `json:"key"`
Address string `json:"address"`
@ -145,7 +117,7 @@ func (c *Config) Validate() error {
case c.Address == "":
return errors.New("address cannot be empty")
case c.Root == "":
case c.Root.HasEmpties():
return errors.New("root cannot be empty")
case c.IntermediateCert == "":

View file

@ -40,7 +40,7 @@ func TestConfigValidate(t *testing.T) {
"empty-address": func(t *testing.T) ConfigValidateTest {
return ConfigValidateTest{
config: &Config{
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
@ -54,7 +54,7 @@ func TestConfigValidate(t *testing.T) {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
@ -81,7 +81,7 @@ func TestConfigValidate(t *testing.T) {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
@ -94,7 +94,7 @@ func TestConfigValidate(t *testing.T) {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
DNSNames: []string{"test.smallstep.com"},
Password: "pass",
@ -107,7 +107,7 @@ func TestConfigValidate(t *testing.T) {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
Password: "pass",
@ -120,7 +120,7 @@ func TestConfigValidate(t *testing.T) {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
@ -134,7 +134,7 @@ func TestConfigValidate(t *testing.T) {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
@ -149,7 +149,7 @@ func TestConfigValidate(t *testing.T) {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},
@ -178,7 +178,7 @@ func TestConfigValidate(t *testing.T) {
return ConfigValidateTest{
config: &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.smallstep.com"},

View file

@ -17,7 +17,7 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
crt, ok := val.(*x509.Certificate)
if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"),
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, context{}}
}
return crt, nil
@ -25,5 +25,39 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
// GetRootCertificate returns the server root certificate.
func (a *Authority) GetRootCertificate() *x509.Certificate {
return a.rootX509Crt
return a.rootX509Certs[0]
}
// GetRootCertificates returns the server root certificates.
//
// In the Authority interface we also have a similar method, GetRoots, at the
// moment the functionality of these two methods are almost identical, but this
// method is intended to be used internally by CA HTTP server to load the roots
// that will be set in the tls.Config while GetRoots will be used by the
// Authority interface and might have extra checks in the future.
func (a *Authority) GetRootCertificates() []*x509.Certificate {
return a.rootX509Certs
}
// GetRoots returns all the root certificates for this CA.
// This method implements the Authority interface.
func (a *Authority) GetRoots() ([]*x509.Certificate, error) {
return a.rootX509Certs, nil
}
// GetFederation returns all the root certificates in the federation.
// This method implements the Authority interface.
func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) {
a.certificates.Range(func(k, v interface{}) bool {
crt, ok := v.(*x509.Certificate)
if !ok {
federation = nil
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, context{}}
return false
}
federation = append(federation, crt)
return true
})
return
}

View file

@ -1,11 +1,14 @@
package authority
import (
"crypto/x509"
"net/http"
"reflect"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/pemutil"
)
func TestRoot(t *testing.T) {
@ -17,7 +20,7 @@ func TestRoot(t *testing.T) {
err *apiError
}{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
}
@ -37,9 +40,116 @@ func TestRoot(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, crt, a.rootX509Crt)
assert.Equals(t, crt, a.rootX509Certs[0])
}
}
})
}
}
func TestAuthority_GetRootCertificate(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
want *x509.Certificate
}{
{"ok", cert},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
if got := a.GetRootCertificate(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetRootCertificate() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthority_GetRootCertificates(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
want []*x509.Certificate
}{
{"ok", []*x509.Certificate{cert}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
if got := a.GetRootCertificates(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetRootCertificates() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthority_GetRoots(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
want []*x509.Certificate
wantErr bool
}{
{"ok", []*x509.Certificate{cert}, false},
}
for _, tt := range tests {
a := testAuthority(t)
t.Run(tt.name, func(t *testing.T) {
got, err := a.GetRoots()
if (err != nil) != tt.wantErr {
t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetRoots() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthority_GetFederation(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
wantFederation []*x509.Certificate
wantErr bool
fn func(a *Authority)
}{
{"ok", []*x509.Certificate{cert}, false, nil},
{"fail", nil, true, func(a *Authority) {
a.certificates.Store("foo", "bar")
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
if tt.fn != nil {
tt.fn(a)
}
gotFederation, err := a.GetFederation()
if (err != nil) != tt.wantErr {
t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotFederation, tt.wantFederation) {
t.Errorf("Authority.GetFederation() = %v, want %v", gotFederation, tt.wantFederation)
}
})
}
}

104
authority/types.go Normal file
View file

@ -0,0 +1,104 @@
package authority
import (
"encoding/json"
"time"
"github.com/pkg/errors"
)
type duration struct {
time.Duration
}
// MarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) UnmarshalJSON(data []byte) (err error) {
var s string
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
if d.Duration, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
return
}
// multiString represents a type that can be encoded/decoded in JSON as a single
// string or an array of strings.
type multiString []string
// First returns the first element of a multiString. It will return an empty
// string if the multistring is empty.
func (s multiString) First() string {
if len(s) > 0 {
return s[0]
}
return ""
}
// HasEmpties returns `true` if any string in the array is empty.
func (s multiString) HasEmpties() bool {
if len(s) == 0 {
return true
}
for _, ss := range s {
if len(ss) == 0 {
return true
}
}
return false
}
// MarshalJSON marshals the multistring as a string or a slice of strings . With
// 0 elements it will return the empty string, with 1 element a regular string,
// otherwise a slice of strings.
func (s multiString) MarshalJSON() ([]byte, error) {
switch len(s) {
case 0:
return []byte(`""`), nil
case 1:
return json.Marshal(s[0])
default:
return json.Marshal([]string(s))
}
}
// UnmarshalJSON parses a string or a slice and sets it to the multiString.
func (s *multiString) UnmarshalJSON(data []byte) error {
if s == nil {
return errors.New("multiString cannot be nil")
}
if len(data) == 0 {
*s = nil
return nil
}
// Parse string
if data[0] == '"' {
var str string
if err := json.Unmarshal(data, &str); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
*s = []string{str}
return nil
}
// Parse array
var ss []string
if err := json.Unmarshal(data, &ss); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
*s = ss
return nil
}

103
authority/types_test.go Normal file
View file

@ -0,0 +1,103 @@
package authority
import (
"reflect"
"testing"
)
func Test_multiString_First(t *testing.T) {
tests := []struct {
name string
s multiString
want string
}{
{"empty", multiString{}, ""},
{"string", multiString{"one"}, "one"},
{"slice", multiString{"one", "two"}, "one"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.s.First(); got != tt.want {
t.Errorf("multiString.First() = %v, want %v", got, tt.want)
}
})
}
}
func Test_multiString_Empties(t *testing.T) {
tests := []struct {
name string
s multiString
want bool
}{
{"empty", multiString{}, true},
{"string", multiString{"one"}, false},
{"empty string", multiString{""}, true},
{"slice", multiString{"one", "two"}, false},
{"empty slice", multiString{"one", ""}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.s.HasEmpties(); got != tt.want {
t.Errorf("multiString.Empties() = %v, want %v", got, tt.want)
}
})
}
}
func Test_multiString_MarshalJSON(t *testing.T) {
tests := []struct {
name string
s multiString
want []byte
wantErr bool
}{
{"empty", []string{}, []byte(`""`), false},
{"string", []string{"a string"}, []byte(`"a string"`), false},
{"slice", []string{"string one", "string two"}, []byte(`["string one","string two"]`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.s.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("multiString.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("multiString.MarshalJSON() = %v, want %v", got, tt.want)
}
})
}
}
func Test_multiString_UnmarshalJSON(t *testing.T) {
type args struct {
data []byte
}
tests := []struct {
name string
s *multiString
args args
want *multiString
wantErr bool
}{
{"empty", new(multiString), args{[]byte{}}, new(multiString), false},
{"empty string", new(multiString), args{[]byte(`""`)}, &multiString{""}, false},
{"string", new(multiString), args{[]byte(`"a string"`)}, &multiString{"a string"}, false},
{"slice", new(multiString), args{[]byte(`["string one","string two"]`)}, &multiString{"string one", "string two"}, false},
{"error", new(multiString), args{[]byte(`["123",123]`)}, new(multiString), true},
{"nil", nil, args{nil}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.s.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("multiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(tt.s, tt.want) {
t.Errorf("multiString.UnmarshalJSON() = %v, want %v", tt.s, tt.want)
}
})
}
}

View file

@ -87,6 +87,9 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio
return nil, err
}
// Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs
options = append(options, AddRootsToCAs())
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
if err != nil {
return nil, err
@ -130,6 +133,9 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*
return nil, err
}
// Make sure the tlsConfig have all supported roots on RootCAs
options = append(options, AddRootsToRootCAs())
transport, err := client.Transport(ctx, sign, pk, options...)
if err != nil {
return nil, err

View file

@ -3,12 +3,15 @@ package ca
import (
"context"
"crypto/tls"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
@ -18,6 +21,24 @@ import (
"gopkg.in/square/go-jose.v2/jwt"
)
func newLocalListener() net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
panic(errors.Wrap(err, "failed to listen on a port"))
}
}
return l
}
func setMinCertDuration(d time.Duration) func() {
tmp := minCertDuration
minCertDuration = 1 * time.Second
return func() {
minCertDuration = tmp
}
}
func startCABootstrapServer() *httptest.Server {
config, err := authority.LoadConfiguration("testdata/ca.json")
if err != nil {
@ -115,8 +136,10 @@ func TestBootstrap(t *testing.T) {
if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) {
t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint)
}
if !reflect.DeepEqual(got.certPool, tt.want.certPool) {
t.Errorf("Bootstrap() certPool = %v, want %v", got.certPool, tt.want.certPool)
gotTR := got.client.Transport.(*http.Transport)
wantTR := tt.want.client.Transport.(*http.Transport)
if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) {
t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs)
}
}
}
@ -267,3 +290,147 @@ func TestBootstrapClient(t *testing.T) {
})
}
}
func TestBootstrapClientServerRotation(t *testing.T) {
reset := setMinCertDuration(1 * time.Second)
defer reset()
// Configuration with current root
config, err := authority.LoadConfiguration("testdata/rotate-ca-0.json")
if err != nil {
t.Fatal(err)
}
// Get local address
listener := newLocalListener()
config.Address = listener.Addr().String()
caURL := "https://" + listener.Addr().String()
// Start CA server
ca, err := New(config)
if err != nil {
t.Fatal(err)
}
go func() {
ca.srv.Serve(listener)
}()
defer ca.Stop()
time.Sleep(1 * time.Second)
// Create bootstrap server
token := generateBootstrapToken(caURL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
server, err := BootstrapServer(context.Background(), token, &http.Server{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("ok"))
}),
}, RequireAndVerifyClientCert())
if err != nil {
t.Fatal(err)
}
listener = newLocalListener()
srvURL := "https://" + listener.Addr().String()
go func() {
server.ServeTLS(listener, "", "")
}()
defer server.Close()
// Create bootstrap client
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
client, err := BootstrapClient(context.Background(), token)
if err != nil {
t.Errorf("BootstrapClient() error = %v", err)
return
}
// doTest does a request that requires mTLS
doTest := func(client *http.Client) error {
// test with ca
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
if err != nil {
return errors.Wrap(err, "client.Post() failed")
}
var renew api.SignResponse
if err := readJSON(resp.Body, &renew); err != nil {
return errors.Wrap(err, "client.Post() error reading response")
}
if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil {
return errors.New("client.Post() unexpected response found")
}
// test with bootstrap server
resp, err = client.Get(srvURL)
if err != nil {
return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errors.Wrap(err, "client.Get() error reading response")
}
if string(b) != "ok" {
return errors.New("client.Get() unexpected response found")
}
return nil
}
// Test with default root
if err := doTest(client); err != nil {
t.Errorf("Test with rotate-ca-0.json failed: %v", err)
}
// wait for renew
time.Sleep(5 * time.Second)
// Reload with configuration with current and future root
ca.opts.configFile = "testdata/rotate-ca-1.json"
if err := doReload(ca); err != nil {
t.Errorf("ca.Reload() error = %v", err)
return
}
if err := doTest(client); err != nil {
t.Errorf("Test with rotate-ca-1.json failed: %v", err)
}
// wait for renew
time.Sleep(5 * time.Second)
// Reload with new and old root
ca.opts.configFile = "testdata/rotate-ca-2.json"
if err := doReload(ca); err != nil {
t.Errorf("ca.Reload() error = %v", err)
return
}
if err := doTest(client); err != nil {
t.Errorf("Test with rotate-ca-2.json failed: %v", err)
}
// wait for renew
time.Sleep(5 * time.Second)
// Reload with pnly the new root
ca.opts.configFile = "testdata/rotate-ca-3.json"
if err := doReload(ca); err != nil {
t.Errorf("ca.Reload() error = %v", err)
return
}
if err := doTest(client); err != nil {
t.Errorf("Test with rotate-ca-3.json failed: %v", err)
}
}
// doReload uses the reload implementation but overwrites the new address with
// the one being used.
func doReload(ca *CA) error {
config, err := authority.LoadConfiguration(ca.opts.configFile)
if err != nil {
return errors.Wrap(err, "error reloading ca")
}
newCA, err := New(config, WithPassword(ca.opts.password), WithConfigFile(ca.opts.configFile))
if err != nil {
return errors.Wrap(err, "error reloading ca")
}
// Use same address in new server
newCA.srv.Addr = ca.srv.Addr
return ca.srv.Reload(newCA.srv)
}

View file

@ -176,7 +176,9 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
}
certPool := x509.NewCertPool()
certPool.AddCert(auth.GetRootCertificate())
for _, crt := range auth.GetRootCertificates() {
certPool.AddCert(crt)
}
// GetCertificate will only be called if the client supplies SNI
// information or if tlsConfig.Certificates is empty.

View file

@ -23,7 +23,6 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"golang.org/x/net/http2"
"gopkg.in/square/go-jose.v2/jwt"
)
@ -239,7 +238,6 @@ func WithProvisionerLimit(limit int) ProvisionerOption {
type Client struct {
client *http.Client
endpoint *url.URL
certPool *x509.CertPool
}
// NewClient creates a new Client with the given endpoint and options.
@ -258,23 +256,11 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
return nil, err
}
var cp *x509.CertPool
switch tr := tr.(type) {
case *http.Transport:
if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil {
cp = tr.TLSClientConfig.RootCAs
}
case *http2.Transport:
if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil {
cp = tr.TLSClientConfig.RootCAs
}
}
return &Client{
client: &http.Client{
Transport: tr,
},
endpoint: u,
certPool: cp,
}, nil
}
@ -413,6 +399,42 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
return &key, nil
}
// Roots performs the get roots request to the CA and returns the
// api.RootsResponse struct.
func (c *Client) Roots() (*api.RootsResponse, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var roots api.RootsResponse
if err := readJSON(resp.Body, &roots); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return &roots, nil
}
// Federation performs the get federation request to the CA and returns the
// api.FederationResponse struct.
func (c *Client) Federation() (*api.FederationResponse, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var federation api.FederationResponse
if err := readJSON(resp.Body, &federation); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return &federation, nil
}
// CreateSignRequest is a helper function that given an x509 OTT returns a
// simple but secure sign request as well as the private key used.
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {

View file

@ -512,6 +512,128 @@ func TestClient_ProvisionerKey(t *testing.T) {
}
}
func TestClient_Roots(t *testing.T) {
ok := &api.RootsResponse{
Certificates: []api.Certificate{
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.Roots()
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Roots() error = %v, want %v", err, tt.response)
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Federation(t *testing.T) {
ok := &api.FederationResponse{
Certificates: []api.Certificate{
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.Federation()
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response)
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
}
}
})
}
}
func Test_parseEndpoint(t *testing.T) {
expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"}
expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"}

View file

@ -14,6 +14,8 @@ import (
// certificate.
type RenewFunc func() (*tls.Certificate, error)
var minCertDuration = time.Minute
// TLSRenewer automatically renews a tls certificate using a RenewFunc.
type TLSRenewer struct {
sync.RWMutex
@ -58,8 +60,8 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
}
period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore)
if period < time.Minute {
return nil, errors.Errorf("period must be greater than or equal to 1 Minute, but got %v.", period)
if period < minCertDuration {
return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period)
}
// By default we will try to renew the cert before 2/3 of the validity
// period have expired.

4
ca/testdata/ca.json vendored
View file

@ -1,5 +1,6 @@
{
"root": "../ca/testdata/secrets/root_ca.crt",
"federatedRoots": ["../ca/testdata/secrets/federated_ca.crt"],
"crt": "../ca/testdata/secrets/intermediate_ca.crt",
"key": "../ca/testdata/secrets/intermediate_ca_key",
"password": "password",
@ -17,7 +18,6 @@
]
},
"authority": {
"minCertDuration": "1m",
"provisioners": [
{
"name": "max",
@ -72,7 +72,7 @@
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
},
"claims": {
"minTLSCertDuration": "30s"
"minTLSCertDuration": "1s"
}
}, {
"name": "mariano",

46
ca/testdata/rotate-ca-0.json vendored Normal file
View file

@ -0,0 +1,46 @@
{
"root": "testdata/secrets/root_ca.crt",
"crt": "testdata/secrets/intermediate_ca.crt",
"key": "testdata/secrets/intermediate_ca_key",
"password": "password",
"address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"],
"logger": {"format": "text"},
"tls": {
"minVersion": 1.2,
"maxVersion": 1.2,
"renegotiation": false,
"cipherSuites": [
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
]
},
"authority": {
"provisioners": [
{
"name": "mariano",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
"key": {
"use": "sig",
"kty": "EC",
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
"crv": "P-256",
"alg": "ES256",
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
},
"claims": {
"minTLSCertDuration": "1s",
"defaultTLSCertDuration": "5s"
}
}
],
"template": {
"country": "US",
"locality": "San Francisco",
"organization": "Smallstep"
}
}
}

46
ca/testdata/rotate-ca-1.json vendored Normal file
View file

@ -0,0 +1,46 @@
{
"root": ["testdata/secrets/root_ca.crt", "testdata/rotated/root_ca.crt"],
"crt": "testdata/secrets/intermediate_ca.crt",
"key": "testdata/secrets/intermediate_ca_key",
"password": "password",
"address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"],
"logger": {"format": "text"},
"tls": {
"minVersion": 1.2,
"maxVersion": 1.2,
"renegotiation": false,
"cipherSuites": [
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
]
},
"authority": {
"provisioners": [
{
"name": "mariano",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
"key": {
"use": "sig",
"kty": "EC",
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
"crv": "P-256",
"alg": "ES256",
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
},
"claims": {
"minTLSCertDuration": "1s",
"defaultTLSCertDuration": "5s"
}
}
],
"template": {
"country": "US",
"locality": "San Francisco",
"organization": "Smallstep"
}
}
}

46
ca/testdata/rotate-ca-2.json vendored Normal file
View file

@ -0,0 +1,46 @@
{
"root": ["testdata/rotated/root_ca.crt", "testdata/secrets/root_ca.crt"],
"crt": "testdata/rotated/intermediate_ca.crt",
"key": "testdata/rotated/intermediate_ca_key",
"password": "asdf",
"address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"],
"logger": {"format": "text"},
"tls": {
"minVersion": 1.2,
"maxVersion": 1.2,
"renegotiation": false,
"cipherSuites": [
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
]
},
"authority": {
"provisioners": [
{
"name": "mariano",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
"key": {
"use": "sig",
"kty": "EC",
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
"crv": "P-256",
"alg": "ES256",
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
},
"claims": {
"minTLSCertDuration": "1s",
"defaultTLSCertDuration": "5s"
}
}
],
"template": {
"country": "US",
"locality": "San Francisco",
"organization": "Smallstep"
}
}
}

46
ca/testdata/rotate-ca-3.json vendored Normal file
View file

@ -0,0 +1,46 @@
{
"root": "testdata/rotated/root_ca.crt",
"crt": "testdata/rotated/intermediate_ca.crt",
"key": "testdata/rotated/intermediate_ca_key",
"password": "asdf",
"address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"],
"logger": {"format": "text"},
"tls": {
"minVersion": 1.2,
"maxVersion": 1.2,
"renegotiation": false,
"cipherSuites": [
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
]
},
"authority": {
"provisioners": [
{
"name": "mariano",
"type": "jwk",
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
"key": {
"use": "sig",
"kty": "EC",
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
"crv": "P-256",
"alg": "ES256",
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
},
"claims": {
"minTLSCertDuration": "1s",
"defaultTLSCertDuration": "5s"
}
}
],
"template": {
"country": "US",
"locality": "San Francisco",
"organization": "Smallstep"
}
}
}

12
ca/testdata/rotated/intermediate_ca.crt vendored Normal file
View file

@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIBxTCCAWugAwIBAgIQLIY6MR/1fBRQY4ZTTsPAJjAKBggqhkjOPQQDAjAcMRow
GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTAxMDcyMDExMzBaFw0yOTAx
MDQyMDExMzBaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew
WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARgtjL/KLNpdq81YYWaek1lrkPM/QF1
m+ujwv5jya21fAXljdBLh6m2xco1GPfwPBbwUGlNOdEqE9Nq3Qx3ngPKo4GGMIGD
MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw
EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUqixeZ/K1HW9N6SVw7ONya98S
u8UwHwYDVR0jBBgwFoAUgIzlCLxh/RlwEany4JQHOorLAIEwCgYIKoZIzj0EAwID
SAAwRQIgdGX6lxThrKlt3v+3HJZlaWdmoeQ3vYwpJb9uHExZdVYCIQDCxsdI8EnB
bxjnJscbT4zvqVsq6AmycdbFwgy8RIeVzg==
-----END CERTIFICATE-----

View file

@ -0,0 +1,8 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-256-CBC,7dcc0a8c1d73c8d438184e0928875329
r6yrQrHg6zBZRSjQpe8RzyQALEfiT3/8lMvvPu3BX6yign5skMfCVMXZhzbmAwmR
BJBIX+5hkudR2VN+hrsOyuU7FvIk4gx2c8buIlFObfYXIml0mpuThfm52ciAtOTE
S0hkfYvPcOAjzaDZ+8Po/mYhkODgyvijogn4ioTF/Ss=
-----END EC PRIVATE KEY-----

11
ca/testdata/rotated/root_ca.crt vendored Normal file
View file

@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa
MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw
MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG
SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL
jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB
/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR
qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z
YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc
Sw==
-----END CERTIFICATE-----

8
ca/testdata/rotated/root_ca_key vendored Normal file
View file

@ -0,0 +1,8 @@
-----BEGIN EC PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-256-CBC,8ce79d28601b9809905ef7c362a20749
H+pTTL3B5fLYycgHLxFOW0fZsayr7Y+BW8THKf12h8dk0/eOE1wNoX2TuMtpbZgO
lMJdFPL+SAPCCmuZOZIcQDejRHVcYBq1wvrrnw/yfVawXC4xze+J4Y+q0J2WY+rM
xcLGlEOIRZkvdDVGmSitEZBl0Ibk0p9tG++7QGqAvnk=
-----END EC PRIVATE KEY-----

11
ca/testdata/secrets/federated_ca.crt vendored Normal file
View file

@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa
MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw
MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG
SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL
jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB
/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR
qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z
YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc
Sw==
-----END CERTIFICATE-----

View file

@ -41,7 +41,8 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
}
// Apply options if given
if err := setTLSOptions(tlsConfig, options); err != nil {
tlsCtx := newTLSOptionCtx(c, tlsConfig)
if err := tlsCtx.apply(options); err != nil {
return nil, err
}
@ -50,7 +51,10 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
if err != nil {
return nil, err
}
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
// Update client transport
c.client.Transport = tr
// Start renewer
renewer.RunContext(ctx)
@ -87,7 +91,8 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
}
// Apply options if given
if err := setTLSOptions(tlsConfig, options); err != nil {
tlsCtx := newTLSOptionCtx(c, tlsConfig)
if err := tlsCtx.apply(options); err != nil {
return nil, err
}
@ -96,7 +101,10 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
if err != nil {
return nil, err
}
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
// Update client transport
c.client.Transport = tr
// Start renewer
renewer.RunContext(ctx)
@ -238,8 +246,13 @@ func getPEM(i interface{}) ([]byte, error) {
return pem.EncodeToMemory(block), nil
}
func getRenewFunc(client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
return func() (*tls.Certificate, error) {
// Get updated list of roots
if err := ctx.applyRenew(); err != nil {
return nil, err
}
// Get new certificate
sign, err := client.Renew(tr)
if err != nil {
return nil, err

View file

@ -6,13 +6,35 @@ import (
)
// TLSOption defines the type of a function that modifies a tls.Config.
type TLSOption func(c *tls.Config) error
type TLSOption func(ctx *TLSOptionCtx) error
// setTLSOptions takes one or more option function and applies them in order to
// a tls.Config.
func setTLSOptions(c *tls.Config, options []TLSOption) error {
for _, opt := range options {
if err := opt(c); err != nil {
// TLSOptionCtx is the context modified on TLSOption methods.
type TLSOptionCtx struct {
Client *Client
Config *tls.Config
OnRenewFunc []TLSOption
}
// newTLSOptionCtx creates the TLSOption context.
func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx {
return &TLSOptionCtx{
Client: c,
Config: config,
}
}
func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
for _, fn := range options {
if err := fn(ctx); err != nil {
return err
}
}
return nil
}
func (ctx *TLSOptionCtx) applyRenew() error {
for _, fn := range ctx.OnRenewFunc {
if err := fn(ctx); err != nil {
return err
}
}
@ -22,8 +44,8 @@ func setTLSOptions(c *tls.Config, options []TLSOption) error {
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
// a valid TLS client certificate. This is the default option for mTLS servers.
func RequireAndVerifyClientCert() TLSOption {
return func(c *tls.Config) error {
c.ClientAuth = tls.RequireAndVerifyClientCert
return func(ctx *TLSOptionCtx) error {
ctx.Config.ClientAuth = tls.RequireAndVerifyClientCert
return nil
}
}
@ -31,8 +53,8 @@ func RequireAndVerifyClientCert() TLSOption {
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
// a TLS client certificate if it is provided. It does not requires a certificate.
func VerifyClientCertIfGiven() TLSOption {
return func(c *tls.Config) error {
c.ClientAuth = tls.VerifyClientCertIfGiven
return func(ctx *TLSOptionCtx) error {
ctx.Config.ClientAuth = tls.VerifyClientCertIfGiven
return nil
}
}
@ -41,11 +63,11 @@ func VerifyClientCertIfGiven() TLSOption {
// defines the set of root certificate authorities that clients use when
// verifying server certificates.
func AddRootCA(cert *x509.Certificate) TLSOption {
return func(c *tls.Config) error {
if c.RootCAs == nil {
c.RootCAs = x509.NewCertPool()
return func(ctx *TLSOptionCtx) error {
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
c.RootCAs.AddCert(cert)
ctx.Config.RootCAs.AddCert(cert)
return nil
}
}
@ -54,11 +76,163 @@ func AddRootCA(cert *x509.Certificate) TLSOption {
// defines the set of root certificate authorities that servers use if required
// to verify a client certificate by the policy in ClientAuth.
func AddClientCA(cert *x509.Certificate) TLSOption {
return func(c *tls.Config) error {
if c.ClientCAs == nil {
c.ClientCAs = x509.NewCertPool()
return func(ctx *TLSOptionCtx) error {
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
c.ClientCAs.AddCert(cert)
ctx.Config.ClientCAs.AddCert(cert)
return nil
}
}
// AddRootsToRootCAs does a roots request and adds to the tls.Config RootCAs all
// the certificates in the response. RootCAs defines the set of root certificate
// authorities that clients use when verifying server certificates.
//
// BootstrapServer and BootstrapClient methods include this option by default.
func AddRootsToRootCAs() TLSOption {
fn := func(ctx *TLSOptionCtx) error {
certs, err := ctx.Client.Roots()
if err != nil {
return err
}
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.RootCAs.AddCert(cert.Certificate)
}
return nil
}
return func(ctx *TLSOptionCtx) error {
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
return fn(ctx)
}
}
// AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs
// all the certificates in the response. ClientCAs defines the set of root
// certificate authorities that servers use if required to verify a client
// certificate by the policy in ClientAuth.
//
// BootstrapServer method includes this option by default.
func AddRootsToClientCAs() TLSOption {
fn := func(ctx *TLSOptionCtx) error {
certs, err := ctx.Client.Roots()
if err != nil {
return err
}
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.ClientCAs.AddCert(cert.Certificate)
}
return nil
}
return func(ctx *TLSOptionCtx) error {
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
return fn(ctx)
}
}
// AddFederationToRootCAs does a federation request and adds to the tls.Config
// RootCAs all the certificates in the response. RootCAs defines the set of root
// certificate authorities that clients use when verifying server certificates.
func AddFederationToRootCAs() TLSOption {
fn := func(ctx *TLSOptionCtx) error {
certs, err := ctx.Client.Federation()
if err != nil {
return err
}
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.RootCAs.AddCert(cert.Certificate)
}
return nil
}
return func(ctx *TLSOptionCtx) error {
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
return fn(ctx)
}
}
// AddFederationToClientCAs does a federation request and adds to the tls.Config
// ClientCAs all the certificates in the response. ClientCAs defines the set of
// root certificate authorities that servers use if required to verify a client
// certificate by the policy in ClientAuth.
func AddFederationToClientCAs() TLSOption {
fn := func(ctx *TLSOptionCtx) error {
certs, err := ctx.Client.Federation()
if err != nil {
return err
}
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.ClientCAs.AddCert(cert.Certificate)
}
return nil
}
return func(ctx *TLSOptionCtx) error {
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
return fn(ctx)
}
}
// AddRootsToCAs does a roots request and adds the resulting certs to the
// tls.Config RootCAs and ClientCAs. Combines the functionality of
// AddRootsToRootCAs and AddRootsToClientCAs.
func AddRootsToCAs() TLSOption {
fn := func(ctx *TLSOptionCtx) error {
certs, err := ctx.Client.Roots()
if err != nil {
return err
}
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.ClientCAs.AddCert(cert.Certificate)
ctx.Config.RootCAs.AddCert(cert.Certificate)
}
return nil
}
return func(ctx *TLSOptionCtx) error {
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
return fn(ctx)
}
}
// AddFederationToCAs does a federation request and adds the resulting certs to the
// tls.Config RootCAs and ClientCAs. Combines the functionality of
// AddFederationToRootCAs and AddFederationToClientCAs.
func AddFederationToCAs() TLSOption {
fn := func(ctx *TLSOptionCtx) error {
certs, err := ctx.Client.Federation()
if err != nil {
return err
}
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.ClientCAs.AddCert(cert.Certificate)
ctx.Config.RootCAs.AddCert(cert.Certificate)
}
return nil
}
return func(ctx *TLSOptionCtx) error {
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
return fn(ctx)
}
}

View file

@ -4,33 +4,69 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"sort"
"testing"
)
func Test_setTLSOptions(t *testing.T) {
fail := func() TLSOption {
return func(c *tls.Config) error {
return fmt.Errorf("an error")
}
func Test_newTLSOptionCtx(t *testing.T) {
client, err := NewClient("https://ca.smallstep.com", WithTransport(http.DefaultTransport))
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
type args struct {
c *tls.Config
options []TLSOption
c *Client
config *tls.Config
}
tests := []struct {
name string
args args
wantErr bool
want *TLSOptionCtx
}{
{"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false},
{"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false},
{"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
{"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) {
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
}
})
}
}
func TestTLSOptionCtx_apply(t *testing.T) {
fail := func() TLSOption {
return func(ctx *TLSOptionCtx) error {
return fmt.Errorf("an error")
}
}
type fields struct {
Config *tls.Config
}
type args struct {
options []TLSOption
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&tls.Config{}}, args{[]TLSOption{RequireAndVerifyClientCert()}}, false},
{"ok", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven()}}, false},
{"fail", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &TLSOptionCtx{
Config: tt.fields.Config,
}
if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr {
t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
@ -45,13 +81,15 @@ func TestRequireAndVerifyClientCert(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{}
if err := RequireAndVerifyClientCert()(got); err != nil {
ctx := &TLSOptionCtx{
Config: &tls.Config{},
}
if err := RequireAndVerifyClientCert()(ctx); err != nil {
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want)
if !reflect.DeepEqual(ctx.Config, tt.want) {
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", ctx.Config, tt.want)
}
})
}
@ -66,13 +104,15 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{}
if err := VerifyClientCertIfGiven()(got); err != nil {
ctx := &TLSOptionCtx{
Config: &tls.Config{},
}
if err := VerifyClientCertIfGiven()(ctx); err != nil {
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want)
if !reflect.DeepEqual(ctx.Config, tt.want) {
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", ctx.Config, tt.want)
}
})
}
@ -95,13 +135,15 @@ func TestAddRootCA(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{}
if err := AddRootCA(tt.args.cert)(got); err != nil {
ctx := &TLSOptionCtx{
Config: &tls.Config{},
}
if err := AddRootCA(tt.args.cert)(ctx); err != nil {
t.Errorf("AddRootCA() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("AddRootCA() = %v, want %v", got, tt.want)
if !reflect.DeepEqual(ctx.Config, tt.want) {
t.Errorf("AddRootCA() = %v, want %v", ctx.Config, tt.want)
}
})
}
@ -124,14 +166,380 @@ func TestAddClientCA(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{}
if err := AddClientCA(tt.args.cert)(got); err != nil {
ctx := &TLSOptionCtx{
Config: &tls.Config{},
}
if err := AddClientCA(tt.args.cert)(ctx); err != nil {
t.Errorf("AddClientCA() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("AddClientCA() = %v, want %v", got, tt.want)
if !reflect.DeepEqual(ctx.Config, tt.want) {
t.Errorf("AddClientCA() = %v, want %v", ctx.Config, tt.want)
}
})
}
}
func TestAddRootsToRootCAs(t *testing.T) {
ca := startCATestServer()
defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil {
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Fatal(err)
}
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
cert := parseCertificate(string(root))
pool := x509.NewCertPool()
pool.AddCert(cert)
type args struct {
client *Client
config *tls.Config
}
tests := []struct {
name string
args args
want *tls.Config
wantErr bool
}{
{"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false},
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &TLSOptionCtx{
Client: tt.args.client,
Config: tt.args.config,
}
if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr {
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
}
})
}
}
func TestAddRootsToClientCAs(t *testing.T) {
ca := startCATestServer()
defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil {
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Fatal(err)
}
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
cert := parseCertificate(string(root))
pool := x509.NewCertPool()
pool.AddCert(cert)
type args struct {
client *Client
config *tls.Config
}
tests := []struct {
name string
args args
want *tls.Config
wantErr bool
}{
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false},
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &TLSOptionCtx{
Client: tt.args.client,
Config: tt.args.config,
}
if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr {
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
}
})
}
}
func TestAddFederationToRootCAs(t *testing.T) {
ca := startCATestServer()
defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil {
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Fatal(err)
}
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil {
t.Fatal(err)
}
crt1 := parseCertificate(string(root))
crt2 := parseCertificate(string(federated))
pool := x509.NewCertPool()
pool.AddCert(crt1)
pool.AddCert(crt2)
type args struct {
client *Client
config *tls.Config
}
tests := []struct {
name string
args args
want *tls.Config
wantErr bool
}{
{"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false},
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &TLSOptionCtx{
Client: tt.args.client,
Config: tt.args.config,
}
if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr {
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
// Federated roots are randomly sorted
if !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) || ctx.Config.ClientCAs != nil {
t.Errorf("AddFederationToRootCAs() = %v, want %v", ctx.Config, tt.want)
}
}
})
}
}
func TestAddFederationToClientCAs(t *testing.T) {
ca := startCATestServer()
defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil {
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Fatal(err)
}
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil {
t.Fatal(err)
}
crt1 := parseCertificate(string(root))
crt2 := parseCertificate(string(federated))
pool := x509.NewCertPool()
pool.AddCert(crt1)
pool.AddCert(crt2)
type args struct {
client *Client
config *tls.Config
}
tests := []struct {
name string
args args
want *tls.Config
wantErr bool
}{
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false},
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &TLSOptionCtx{
Client: tt.args.client,
Config: tt.args.config,
}
if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr {
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
// Federated roots are randomly sorted
if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || ctx.Config.RootCAs != nil {
t.Errorf("AddFederationToClientCAs() = %v, want %v", ctx.Config, tt.want)
}
}
})
}
}
func TestAddRootsToCAs(t *testing.T) {
ca := startCATestServer()
defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil {
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Fatal(err)
}
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
cert := parseCertificate(string(root))
pool := x509.NewCertPool()
pool.AddCert(cert)
type args struct {
client *Client
config *tls.Config
}
tests := []struct {
name string
args args
want *tls.Config
wantErr bool
}{
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &TLSOptionCtx{
Client: tt.args.client,
Config: tt.args.config,
}
if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr {
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
}
})
}
}
func TestAddFederationToCAs(t *testing.T) {
ca := startCATestServer()
defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil {
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Fatal(err)
}
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil {
t.Fatal(err)
}
crt1 := parseCertificate(string(root))
crt2 := parseCertificate(string(federated))
pool := x509.NewCertPool()
pool.AddCert(crt1)
pool.AddCert(crt2)
type args struct {
client *Client
config *tls.Config
}
tests := []struct {
name string
args args
want *tls.Config
wantErr bool
}{
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &TLSOptionCtx{
Client: tt.args.client,
Config: tt.args.config,
}
if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr {
t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
// Federated roots are randomly sorted
if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) {
t.Errorf("AddFederationToCAs() = %v, want %v", ctx.Config, tt.want)
}
}
})
}
}
func equalPools(a, b *x509.CertPool) bool {
subjects := a.Subjects()
sA := make([]string, len(subjects))
for i := range subjects {
sA[i] = string(subjects[i])
}
subjects = b.Subjects()
sB := make([]string, len(subjects))
for i := range subjects {
sB[i] = string(subjects[i])
}
sort.Sort(sort.StringSlice(sA))
sort.Sort(sort.StringSlice(sB))
return reflect.DeepEqual(sA, sB)
}

View file

@ -20,7 +20,7 @@ import (
"github.com/smallstep/certificates/authority"
"github.com/smallstep/cli/crypto/randutil"
stepJOSE "github.com/smallstep/cli/jose"
"gopkg.in/square/go-jose.v2"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
@ -242,16 +242,15 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
}
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
reset := setMinCertDuration(1 * time.Second)
defer reset()
// Start CA
ca := startCATestServer()
defer ca.Close()
clientDomain := "test.domain"
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second)
// Start mTLS server
ctx, cancel := context.WithCancel(context.Background())
@ -274,13 +273,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
defer srvTLS.Close()
// Transport
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second)
tr1, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.Transport() error = %v", err)
}
// Transport with tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
@ -367,9 +366,9 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
}
// Wait for renewal 40s == 1m-1m/3
log.Printf("Sleeping for %s ...\n", 40*time.Second)
time.Sleep(40 * time.Second)
// Wait for renewal
log.Printf("Sleeping for %s ...\n", 5*time.Second)
time.Sleep(5 * time.Second)
for _, tt := range tests {
t.Run("renewed "+tt.name, func(t *testing.T) {