forked from TrueCloudLab/certificates
Merge pull request #22 from smallstep/mariano/multiroot
Multiple roots and federation
This commit is contained in:
commit
d0e0217955
30 changed files with 1782 additions and 136 deletions
54
api/api.go
54
api/api.go
|
@ -22,9 +22,11 @@ type Authority interface {
|
|||
GetTLSOptions() *tlsutil.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
|
||||
Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
|
||||
GetEncryptedKey(kid string) (string, error)
|
||||
GetRoots() (federation []*x509.Certificate, err error)
|
||||
GetFederation() ([]*x509.Certificate, error)
|
||||
}
|
||||
|
||||
// Certificate wraps a *x509.Certificate and adds the json.Marshaler interface.
|
||||
|
@ -186,6 +188,16 @@ type SignResponse struct {
|
|||
TLS *tls.ConnectionState `json:"-"`
|
||||
}
|
||||
|
||||
// RootsResponse is the response object of the roots request.
|
||||
type RootsResponse struct {
|
||||
Certificates []Certificate `json:"crts"`
|
||||
}
|
||||
|
||||
// FederationResponse is the response object of the federation request.
|
||||
type FederationResponse struct {
|
||||
Certificates []Certificate `json:"crts"`
|
||||
}
|
||||
|
||||
// caHandler is the type used to implement the different CA HTTP endpoints.
|
||||
type caHandler struct {
|
||||
Authority Authority
|
||||
|
@ -205,6 +217,8 @@ func (h *caHandler) Route(r Router) {
|
|||
r.MethodFunc("POST", "/renew", h.Renew)
|
||||
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", h.Roots)
|
||||
r.MethodFunc("GET", "/federation", h.Federation)
|
||||
// For compatibility with old code:
|
||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||
}
|
||||
|
@ -320,6 +334,44 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
|||
JSON(w, &ProvisionerKeyResponse{key})
|
||||
}
|
||||
|
||||
// Roots returns all the root certificates for the CA.
|
||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
if err != nil {
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
|
||||
certs := make([]Certificate, len(roots))
|
||||
for i := range roots {
|
||||
certs[i] = Certificate{roots[i]}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
JSON(w, &RootsResponse{
|
||||
Certificates: certs,
|
||||
})
|
||||
}
|
||||
|
||||
// Federation returns all the public certificates in the federation.
|
||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := h.Authority.GetFederation()
|
||||
if err != nil {
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
|
||||
certs := make([]Certificate, len(federated))
|
||||
for i := range federated {
|
||||
certs[i] = Certificate{federated[i]}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
JSON(w, &FederationResponse{
|
||||
Certificates: certs,
|
||||
})
|
||||
}
|
||||
|
||||
func parseCursor(r *http.Request) (cursor string, limit int, err error) {
|
||||
q := r.URL.Query()
|
||||
cursor = q.Get("cursor")
|
||||
|
|
108
api/api_test.go
108
api/api_test.go
|
@ -392,6 +392,8 @@ type mockAuthority struct {
|
|||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
|
||||
getEncryptedKey func(kid string) (string, error)
|
||||
getRoots func() ([]*x509.Certificate, error)
|
||||
getFederation func() ([]*x509.Certificate, error)
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
|
||||
|
@ -443,6 +445,20 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
|
|||
return m.ret1.(string), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) {
|
||||
if m.getRoots != nil {
|
||||
return m.getRoots()
|
||||
}
|
||||
return m.ret1.([]*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
|
||||
if m.getFederation != nil {
|
||||
return m.getFederation()
|
||||
}
|
||||
return m.ret1.([]*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func Test_caHandler_Route(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
|
@ -812,3 +828,95 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Roots(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
tls *tls.ConnectionState
|
||||
cert *x509.Certificate
|
||||
root *x509.Certificate
|
||||
err error
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||
{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
h.Roots(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.Roots unexpected error = %v", err)
|
||||
}
|
||||
if tt.statusCode < http.StatusBadRequest {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||||
t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Federation(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
tls *tls.ConnectionState
|
||||
cert *x509.Certificate
|
||||
root *x509.Certificate
|
||||
err error
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||
{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
h.Federation(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.Federation unexpected error = %v", err)
|
||||
}
|
||||
if tt.statusCode < http.StatusBadRequest {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||||
t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ package authority
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
realx509 "crypto/x509"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority"
|
|||
// Authority implements the Certificate Authority internal interface.
|
||||
type Authority struct {
|
||||
config *Config
|
||||
rootX509Crt *realx509.Certificate
|
||||
rootX509Certs []*x509.Certificate
|
||||
intermediateIdentity *x509util.Identity
|
||||
validateOnce bool
|
||||
certificates *sync.Map
|
||||
|
@ -79,15 +79,29 @@ func (a *Authority) init() error {
|
|||
}
|
||||
|
||||
var err error
|
||||
// First load the root using our modified pem/x509 package.
|
||||
a.rootX509Crt, err = pemutil.ReadCertificate(a.config.Root)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
// Load the root certificates and add them to the certificate store
|
||||
a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root))
|
||||
for i, path := range a.config.Root {
|
||||
crt, err := pemutil.ReadCertificate(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Add root certificate to the certificate map
|
||||
sum := sha256.Sum256(crt.Raw)
|
||||
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
|
||||
a.rootX509Certs[i] = crt
|
||||
}
|
||||
|
||||
// Add root certificate to the certificate map
|
||||
sum := sha256.Sum256(a.rootX509Crt.Raw)
|
||||
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
|
||||
// Add federated roots
|
||||
for _, path := range a.config.FederatedRoots {
|
||||
crt, err := pemutil.ReadCertificate(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sum := sha256.Sum256(crt.Raw)
|
||||
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
|
||||
}
|
||||
|
||||
// Decrypt and load intermediate public / private key pair.
|
||||
if len(a.config.Password) > 0 {
|
||||
|
|
|
@ -38,7 +38,7 @@ func testAuthority(t *testing.T) *Authority {
|
|||
}
|
||||
c := &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.ca.smallstep.com"},
|
||||
|
@ -68,7 +68,7 @@ func TestAuthorityNew(t *testing.T) {
|
|||
"fail bad root": func(t *testing.T) *newTest {
|
||||
c, err := LoadConfiguration("../ca/testdata/ca.json")
|
||||
assert.FatalError(t, err)
|
||||
c.Root = "foo"
|
||||
c.Root = []string{"foo"}
|
||||
return &newTest{
|
||||
config: c,
|
||||
err: errors.New("open foo failed: no such file or directory"),
|
||||
|
@ -105,10 +105,10 @@ func TestAuthorityNew(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
sum := sha256.Sum256(auth.rootX509Crt.Raw)
|
||||
sum := sha256.Sum256(auth.rootX509Certs[0].Raw)
|
||||
root, ok := auth.certificates.Load(hex.EncodeToString(sum[:]))
|
||||
assert.Fatal(t, ok)
|
||||
assert.Equals(t, auth.rootX509Crt, root)
|
||||
assert.Equals(t, auth.rootX509Certs[0], root)
|
||||
|
||||
assert.True(t, auth.initOnce)
|
||||
assert.NotNil(t, auth.intermediateIdentity)
|
||||
|
|
|
@ -33,38 +33,10 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
type duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *duration) UnmarshalJSON(data []byte) (err error) {
|
||||
var s string
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
if d.Duration, err = time.ParseDuration(s); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Config represents the CA configuration and it's mapped to a JSON object.
|
||||
type Config struct {
|
||||
Root string `json:"root"`
|
||||
Root multiString `json:"root"`
|
||||
FederatedRoots []string `json:"federatedRoots"`
|
||||
IntermediateCert string `json:"crt"`
|
||||
IntermediateKey string `json:"key"`
|
||||
Address string `json:"address"`
|
||||
|
@ -145,7 +117,7 @@ func (c *Config) Validate() error {
|
|||
case c.Address == "":
|
||||
return errors.New("address cannot be empty")
|
||||
|
||||
case c.Root == "":
|
||||
case c.Root.HasEmpties():
|
||||
return errors.New("root cannot be empty")
|
||||
|
||||
case c.IntermediateCert == "":
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
"empty-address": func(t *testing.T) ConfigValidateTest {
|
||||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
|
@ -54,7 +54,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
|
@ -81,7 +81,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
|
@ -94,7 +94,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
|
@ -107,7 +107,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Password: "pass",
|
||||
|
@ -120,7 +120,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
|
@ -134,7 +134,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
|
@ -149,7 +149,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
|
@ -178,7 +178,7 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: "testdata/secrets/root_ca.crt",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
|
|
|
@ -17,7 +17,7 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
|||
|
||||
crt, ok := val.(*x509.Certificate)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"),
|
||||
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
}
|
||||
return crt, nil
|
||||
|
@ -25,5 +25,39 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
|||
|
||||
// GetRootCertificate returns the server root certificate.
|
||||
func (a *Authority) GetRootCertificate() *x509.Certificate {
|
||||
return a.rootX509Crt
|
||||
return a.rootX509Certs[0]
|
||||
}
|
||||
|
||||
// GetRootCertificates returns the server root certificates.
|
||||
//
|
||||
// In the Authority interface we also have a similar method, GetRoots, at the
|
||||
// moment the functionality of these two methods are almost identical, but this
|
||||
// method is intended to be used internally by CA HTTP server to load the roots
|
||||
// that will be set in the tls.Config while GetRoots will be used by the
|
||||
// Authority interface and might have extra checks in the future.
|
||||
func (a *Authority) GetRootCertificates() []*x509.Certificate {
|
||||
return a.rootX509Certs
|
||||
}
|
||||
|
||||
// GetRoots returns all the root certificates for this CA.
|
||||
// This method implements the Authority interface.
|
||||
func (a *Authority) GetRoots() ([]*x509.Certificate, error) {
|
||||
return a.rootX509Certs, nil
|
||||
}
|
||||
|
||||
// GetFederation returns all the root certificates in the federation.
|
||||
// This method implements the Authority interface.
|
||||
func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) {
|
||||
a.certificates.Range(func(k, v interface{}) bool {
|
||||
crt, ok := v.(*x509.Certificate)
|
||||
if !ok {
|
||||
federation = nil
|
||||
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
return false
|
||||
}
|
||||
federation = append(federation, crt)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestRoot(t *testing.T) {
|
||||
|
@ -17,7 +20,7 @@ func TestRoot(t *testing.T) {
|
|||
err *apiError
|
||||
}{
|
||||
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
|
||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}},
|
||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
|
||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
||||
}
|
||||
|
||||
|
@ -37,9 +40,116 @@ func TestRoot(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, crt, a.rootX509Crt)
|
||||
assert.Equals(t, crt, a.rootX509Certs[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_GetRootCertificate(t *testing.T) {
|
||||
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want *x509.Certificate
|
||||
}{
|
||||
{"ok", cert},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := testAuthority(t)
|
||||
if got := a.GetRootCertificate(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authority.GetRootCertificate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_GetRootCertificates(t *testing.T) {
|
||||
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want []*x509.Certificate
|
||||
}{
|
||||
{"ok", []*x509.Certificate{cert}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := testAuthority(t)
|
||||
if got := a.GetRootCertificates(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authority.GetRootCertificates() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_GetRoots(t *testing.T) {
|
||||
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want []*x509.Certificate
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", []*x509.Certificate{cert}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
a := testAuthority(t)
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := a.GetRoots()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authority.GetRoots() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_GetFederation(t *testing.T) {
|
||||
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
wantFederation []*x509.Certificate
|
||||
wantErr bool
|
||||
fn func(a *Authority)
|
||||
}{
|
||||
{"ok", []*x509.Certificate{cert}, false, nil},
|
||||
{"fail", nil, true, func(a *Authority) {
|
||||
a.certificates.Store("foo", "bar")
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := testAuthority(t)
|
||||
if tt.fn != nil {
|
||||
tt.fn(a)
|
||||
}
|
||||
gotFederation, err := a.GetFederation()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotFederation, tt.wantFederation) {
|
||||
t.Errorf("Authority.GetFederation() = %v, want %v", gotFederation, tt.wantFederation)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
104
authority/types.go
Normal file
104
authority/types.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *duration) UnmarshalJSON(data []byte) (err error) {
|
||||
var s string
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
if d.Duration, err = time.ParseDuration(s); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// multiString represents a type that can be encoded/decoded in JSON as a single
|
||||
// string or an array of strings.
|
||||
type multiString []string
|
||||
|
||||
// First returns the first element of a multiString. It will return an empty
|
||||
// string if the multistring is empty.
|
||||
func (s multiString) First() string {
|
||||
if len(s) > 0 {
|
||||
return s[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// HasEmpties returns `true` if any string in the array is empty.
|
||||
func (s multiString) HasEmpties() bool {
|
||||
if len(s) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, ss := range s {
|
||||
if len(ss) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the multistring as a string or a slice of strings . With
|
||||
// 0 elements it will return the empty string, with 1 element a regular string,
|
||||
// otherwise a slice of strings.
|
||||
func (s multiString) MarshalJSON() ([]byte, error) {
|
||||
switch len(s) {
|
||||
case 0:
|
||||
return []byte(`""`), nil
|
||||
case 1:
|
||||
return json.Marshal(s[0])
|
||||
default:
|
||||
return json.Marshal([]string(s))
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a string or a slice and sets it to the multiString.
|
||||
func (s *multiString) UnmarshalJSON(data []byte) error {
|
||||
if s == nil {
|
||||
return errors.New("multiString cannot be nil")
|
||||
}
|
||||
if len(data) == 0 {
|
||||
*s = nil
|
||||
return nil
|
||||
}
|
||||
// Parse string
|
||||
if data[0] == '"' {
|
||||
var str string
|
||||
if err := json.Unmarshal(data, &str); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
*s = []string{str}
|
||||
return nil
|
||||
}
|
||||
// Parse array
|
||||
var ss []string
|
||||
if err := json.Unmarshal(data, &ss); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
*s = ss
|
||||
return nil
|
||||
}
|
103
authority/types_test.go
Normal file
103
authority/types_test.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_multiString_First(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s multiString
|
||||
want string
|
||||
}{
|
||||
{"empty", multiString{}, ""},
|
||||
{"string", multiString{"one"}, "one"},
|
||||
{"slice", multiString{"one", "two"}, "one"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.s.First(); got != tt.want {
|
||||
t.Errorf("multiString.First() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_multiString_Empties(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s multiString
|
||||
want bool
|
||||
}{
|
||||
{"empty", multiString{}, true},
|
||||
{"string", multiString{"one"}, false},
|
||||
{"empty string", multiString{""}, true},
|
||||
{"slice", multiString{"one", "two"}, false},
|
||||
{"empty slice", multiString{"one", ""}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.s.HasEmpties(); got != tt.want {
|
||||
t.Errorf("multiString.Empties() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_multiString_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s multiString
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty", []string{}, []byte(`""`), false},
|
||||
{"string", []string{"a string"}, []byte(`"a string"`), false},
|
||||
{"slice", []string{"string one", "string two"}, []byte(`["string one","string two"]`), false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.s.MarshalJSON()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("multiString.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("multiString.MarshalJSON() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_multiString_UnmarshalJSON(t *testing.T) {
|
||||
|
||||
type args struct {
|
||||
data []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
s *multiString
|
||||
args args
|
||||
want *multiString
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty", new(multiString), args{[]byte{}}, new(multiString), false},
|
||||
{"empty string", new(multiString), args{[]byte(`""`)}, &multiString{""}, false},
|
||||
{"string", new(multiString), args{[]byte(`"a string"`)}, &multiString{"a string"}, false},
|
||||
{"slice", new(multiString), args{[]byte(`["string one","string two"]`)}, &multiString{"string one", "string two"}, false},
|
||||
{"error", new(multiString), args{[]byte(`["123",123]`)}, new(multiString), true},
|
||||
{"nil", nil, args{nil}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.s.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||
t.Errorf("multiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(tt.s, tt.want) {
|
||||
t.Errorf("multiString.UnmarshalJSON() = %v, want %v", tt.s, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -87,6 +87,9 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs
|
||||
options = append(options, AddRootsToCAs())
|
||||
|
||||
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -130,6 +133,9 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Make sure the tlsConfig have all supported roots on RootCAs
|
||||
options = append(options, AddRootsToRootCAs())
|
||||
|
||||
transport, err := client.Transport(ctx, sign, pk, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -3,12 +3,15 @@ package ca
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
|
||||
|
@ -18,6 +21,24 @@ import (
|
|||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func newLocalListener() net.Listener {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
||||
panic(errors.Wrap(err, "failed to listen on a port"))
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func setMinCertDuration(d time.Duration) func() {
|
||||
tmp := minCertDuration
|
||||
minCertDuration = 1 * time.Second
|
||||
return func() {
|
||||
minCertDuration = tmp
|
||||
}
|
||||
}
|
||||
|
||||
func startCABootstrapServer() *httptest.Server {
|
||||
config, err := authority.LoadConfiguration("testdata/ca.json")
|
||||
if err != nil {
|
||||
|
@ -115,8 +136,10 @@ func TestBootstrap(t *testing.T) {
|
|||
if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) {
|
||||
t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint)
|
||||
}
|
||||
if !reflect.DeepEqual(got.certPool, tt.want.certPool) {
|
||||
t.Errorf("Bootstrap() certPool = %v, want %v", got.certPool, tt.want.certPool)
|
||||
gotTR := got.client.Transport.(*http.Transport)
|
||||
wantTR := tt.want.client.Transport.(*http.Transport)
|
||||
if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) {
|
||||
t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -267,3 +290,147 @@ func TestBootstrapClient(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapClientServerRotation(t *testing.T) {
|
||||
reset := setMinCertDuration(1 * time.Second)
|
||||
defer reset()
|
||||
|
||||
// Configuration with current root
|
||||
config, err := authority.LoadConfiguration("testdata/rotate-ca-0.json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get local address
|
||||
listener := newLocalListener()
|
||||
config.Address = listener.Addr().String()
|
||||
caURL := "https://" + listener.Addr().String()
|
||||
|
||||
// Start CA server
|
||||
ca, err := New(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
ca.srv.Serve(listener)
|
||||
}()
|
||||
defer ca.Stop()
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Create bootstrap server
|
||||
token := generateBootstrapToken(caURL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
|
||||
server, err := BootstrapServer(context.Background(), token, &http.Server{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write([]byte("ok"))
|
||||
}),
|
||||
}, RequireAndVerifyClientCert())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
listener = newLocalListener()
|
||||
srvURL := "https://" + listener.Addr().String()
|
||||
go func() {
|
||||
server.ServeTLS(listener, "", "")
|
||||
}()
|
||||
defer server.Close()
|
||||
|
||||
// Create bootstrap client
|
||||
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
|
||||
client, err := BootstrapClient(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Errorf("BootstrapClient() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// doTest does a request that requires mTLS
|
||||
doTest := func(client *http.Client) error {
|
||||
// test with ca
|
||||
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "client.Post() failed")
|
||||
}
|
||||
var renew api.SignResponse
|
||||
if err := readJSON(resp.Body, &renew); err != nil {
|
||||
return errors.Wrap(err, "client.Post() error reading response")
|
||||
}
|
||||
if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil {
|
||||
return errors.New("client.Post() unexpected response found")
|
||||
}
|
||||
// test with bootstrap server
|
||||
resp, err = client.Get(srvURL)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "client.Get() error reading response")
|
||||
}
|
||||
if string(b) != "ok" {
|
||||
return errors.New("client.Get() unexpected response found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test with default root
|
||||
if err := doTest(client); err != nil {
|
||||
t.Errorf("Test with rotate-ca-0.json failed: %v", err)
|
||||
}
|
||||
|
||||
// wait for renew
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Reload with configuration with current and future root
|
||||
ca.opts.configFile = "testdata/rotate-ca-1.json"
|
||||
if err := doReload(ca); err != nil {
|
||||
t.Errorf("ca.Reload() error = %v", err)
|
||||
return
|
||||
}
|
||||
if err := doTest(client); err != nil {
|
||||
t.Errorf("Test with rotate-ca-1.json failed: %v", err)
|
||||
}
|
||||
|
||||
// wait for renew
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Reload with new and old root
|
||||
ca.opts.configFile = "testdata/rotate-ca-2.json"
|
||||
if err := doReload(ca); err != nil {
|
||||
t.Errorf("ca.Reload() error = %v", err)
|
||||
return
|
||||
}
|
||||
if err := doTest(client); err != nil {
|
||||
t.Errorf("Test with rotate-ca-2.json failed: %v", err)
|
||||
}
|
||||
|
||||
// wait for renew
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Reload with pnly the new root
|
||||
ca.opts.configFile = "testdata/rotate-ca-3.json"
|
||||
if err := doReload(ca); err != nil {
|
||||
t.Errorf("ca.Reload() error = %v", err)
|
||||
return
|
||||
}
|
||||
if err := doTest(client); err != nil {
|
||||
t.Errorf("Test with rotate-ca-3.json failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// doReload uses the reload implementation but overwrites the new address with
|
||||
// the one being used.
|
||||
func doReload(ca *CA) error {
|
||||
config, err := authority.LoadConfiguration(ca.opts.configFile)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error reloading ca")
|
||||
}
|
||||
|
||||
newCA, err := New(config, WithPassword(ca.opts.password), WithConfigFile(ca.opts.configFile))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error reloading ca")
|
||||
}
|
||||
// Use same address in new server
|
||||
newCA.srv.Addr = ca.srv.Addr
|
||||
return ca.srv.Reload(newCA.srv)
|
||||
}
|
||||
|
|
4
ca/ca.go
4
ca/ca.go
|
@ -176,7 +176,9 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
|
|||
}
|
||||
|
||||
certPool := x509.NewCertPool()
|
||||
certPool.AddCert(auth.GetRootCertificate())
|
||||
for _, crt := range auth.GetRootCertificates() {
|
||||
certPool.AddCert(crt)
|
||||
}
|
||||
|
||||
// GetCertificate will only be called if the client supplies SNI
|
||||
// information or if tlsConfig.Certificates is empty.
|
||||
|
|
50
ca/client.go
50
ca/client.go
|
@ -23,7 +23,6 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"golang.org/x/net/http2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
|
@ -239,7 +238,6 @@ func WithProvisionerLimit(limit int) ProvisionerOption {
|
|||
type Client struct {
|
||||
client *http.Client
|
||||
endpoint *url.URL
|
||||
certPool *x509.CertPool
|
||||
}
|
||||
|
||||
// NewClient creates a new Client with the given endpoint and options.
|
||||
|
@ -258,23 +256,11 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var cp *x509.CertPool
|
||||
switch tr := tr.(type) {
|
||||
case *http.Transport:
|
||||
if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil {
|
||||
cp = tr.TLSClientConfig.RootCAs
|
||||
}
|
||||
case *http2.Transport:
|
||||
if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil {
|
||||
cp = tr.TLSClientConfig.RootCAs
|
||||
}
|
||||
}
|
||||
return &Client{
|
||||
client: &http.Client{
|
||||
Transport: tr,
|
||||
},
|
||||
endpoint: u,
|
||||
certPool: cp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -413,6 +399,42 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
|
|||
return &key, nil
|
||||
}
|
||||
|
||||
// Roots performs the get roots request to the CA and returns the
|
||||
// api.RootsResponse struct.
|
||||
func (c *Client) Roots() (*api.RootsResponse, error) {
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
|
||||
resp, err := c.client.Get(u.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readError(resp.Body)
|
||||
}
|
||||
var roots api.RootsResponse
|
||||
if err := readJSON(resp.Body, &roots); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
}
|
||||
return &roots, nil
|
||||
}
|
||||
|
||||
// Federation performs the get federation request to the CA and returns the
|
||||
// api.FederationResponse struct.
|
||||
func (c *Client) Federation() (*api.FederationResponse, error) {
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
|
||||
resp, err := c.client.Get(u.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readError(resp.Body)
|
||||
}
|
||||
var federation api.FederationResponse
|
||||
if err := readJSON(resp.Body, &federation); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
}
|
||||
return &federation, nil
|
||||
}
|
||||
|
||||
// CreateSignRequest is a helper function that given an x509 OTT returns a
|
||||
// simple but secure sign request as well as the private key used.
|
||||
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {
|
||||
|
|
|
@ -512,6 +512,128 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClient_Roots(t *testing.T) {
|
||||
ok := &api.RootsResponse{
|
||||
Certificates: []api.Certificate{
|
||||
{Certificate: parseCertificate(rootPEM)},
|
||||
},
|
||||
}
|
||||
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"unauthorized", unauthorized, 401, true},
|
||||
{"empty request", badRequest, 403, true},
|
||||
{"nil request", badRequest, 403, true},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
defer srv.Close()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Errorf("NewClient() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
})
|
||||
|
||||
got, err := c.Roots()
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Printf("%+v", err)
|
||||
t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case err != nil:
|
||||
if got != nil {
|
||||
t.Errorf("Client.Roots() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Roots() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Federation(t *testing.T) {
|
||||
ok := &api.FederationResponse{
|
||||
Certificates: []api.Certificate{
|
||||
{Certificate: parseCertificate(rootPEM)},
|
||||
},
|
||||
}
|
||||
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"unauthorized", unauthorized, 401, true},
|
||||
{"empty request", badRequest, 403, true},
|
||||
{"nil request", badRequest, 403, true},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
defer srv.Close()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Errorf("NewClient() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
})
|
||||
|
||||
got, err := c.Federation()
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Printf("%+v", err)
|
||||
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case err != nil:
|
||||
if got != nil {
|
||||
t.Errorf("Client.Federation() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseEndpoint(t *testing.T) {
|
||||
expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"}
|
||||
expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"}
|
||||
|
|
|
@ -14,6 +14,8 @@ import (
|
|||
// certificate.
|
||||
type RenewFunc func() (*tls.Certificate, error)
|
||||
|
||||
var minCertDuration = time.Minute
|
||||
|
||||
// TLSRenewer automatically renews a tls certificate using a RenewFunc.
|
||||
type TLSRenewer struct {
|
||||
sync.RWMutex
|
||||
|
@ -58,8 +60,8 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
|
|||
}
|
||||
|
||||
period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore)
|
||||
if period < time.Minute {
|
||||
return nil, errors.Errorf("period must be greater than or equal to 1 Minute, but got %v.", period)
|
||||
if period < minCertDuration {
|
||||
return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period)
|
||||
}
|
||||
// By default we will try to renew the cert before 2/3 of the validity
|
||||
// period have expired.
|
||||
|
|
4
ca/testdata/ca.json
vendored
4
ca/testdata/ca.json
vendored
|
@ -1,5 +1,6 @@
|
|||
{
|
||||
"root": "../ca/testdata/secrets/root_ca.crt",
|
||||
"federatedRoots": ["../ca/testdata/secrets/federated_ca.crt"],
|
||||
"crt": "../ca/testdata/secrets/intermediate_ca.crt",
|
||||
"key": "../ca/testdata/secrets/intermediate_ca_key",
|
||||
"password": "password",
|
||||
|
@ -17,7 +18,6 @@
|
|||
]
|
||||
},
|
||||
"authority": {
|
||||
"minCertDuration": "1m",
|
||||
"provisioners": [
|
||||
{
|
||||
"name": "max",
|
||||
|
@ -72,7 +72,7 @@
|
|||
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
|
||||
},
|
||||
"claims": {
|
||||
"minTLSCertDuration": "30s"
|
||||
"minTLSCertDuration": "1s"
|
||||
}
|
||||
}, {
|
||||
"name": "mariano",
|
||||
|
|
46
ca/testdata/rotate-ca-0.json
vendored
Normal file
46
ca/testdata/rotate-ca-0.json
vendored
Normal file
|
@ -0,0 +1,46 @@
|
|||
{
|
||||
"root": "testdata/secrets/root_ca.crt",
|
||||
"crt": "testdata/secrets/intermediate_ca.crt",
|
||||
"key": "testdata/secrets/intermediate_ca_key",
|
||||
"password": "password",
|
||||
"address": "127.0.0.1:0",
|
||||
"dnsNames": ["127.0.0.1"],
|
||||
"logger": {"format": "text"},
|
||||
"tls": {
|
||||
"minVersion": 1.2,
|
||||
"maxVersion": 1.2,
|
||||
"renegotiation": false,
|
||||
"cipherSuites": [
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
|
||||
]
|
||||
},
|
||||
"authority": {
|
||||
"provisioners": [
|
||||
{
|
||||
"name": "mariano",
|
||||
"type": "jwk",
|
||||
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
|
||||
"key": {
|
||||
"use": "sig",
|
||||
"kty": "EC",
|
||||
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
|
||||
"crv": "P-256",
|
||||
"alg": "ES256",
|
||||
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
|
||||
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
|
||||
},
|
||||
"claims": {
|
||||
"minTLSCertDuration": "1s",
|
||||
"defaultTLSCertDuration": "5s"
|
||||
}
|
||||
}
|
||||
],
|
||||
"template": {
|
||||
"country": "US",
|
||||
"locality": "San Francisco",
|
||||
"organization": "Smallstep"
|
||||
}
|
||||
}
|
||||
}
|
46
ca/testdata/rotate-ca-1.json
vendored
Normal file
46
ca/testdata/rotate-ca-1.json
vendored
Normal file
|
@ -0,0 +1,46 @@
|
|||
{
|
||||
"root": ["testdata/secrets/root_ca.crt", "testdata/rotated/root_ca.crt"],
|
||||
"crt": "testdata/secrets/intermediate_ca.crt",
|
||||
"key": "testdata/secrets/intermediate_ca_key",
|
||||
"password": "password",
|
||||
"address": "127.0.0.1:0",
|
||||
"dnsNames": ["127.0.0.1"],
|
||||
"logger": {"format": "text"},
|
||||
"tls": {
|
||||
"minVersion": 1.2,
|
||||
"maxVersion": 1.2,
|
||||
"renegotiation": false,
|
||||
"cipherSuites": [
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
|
||||
]
|
||||
},
|
||||
"authority": {
|
||||
"provisioners": [
|
||||
{
|
||||
"name": "mariano",
|
||||
"type": "jwk",
|
||||
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
|
||||
"key": {
|
||||
"use": "sig",
|
||||
"kty": "EC",
|
||||
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
|
||||
"crv": "P-256",
|
||||
"alg": "ES256",
|
||||
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
|
||||
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
|
||||
},
|
||||
"claims": {
|
||||
"minTLSCertDuration": "1s",
|
||||
"defaultTLSCertDuration": "5s"
|
||||
}
|
||||
}
|
||||
],
|
||||
"template": {
|
||||
"country": "US",
|
||||
"locality": "San Francisco",
|
||||
"organization": "Smallstep"
|
||||
}
|
||||
}
|
||||
}
|
46
ca/testdata/rotate-ca-2.json
vendored
Normal file
46
ca/testdata/rotate-ca-2.json
vendored
Normal file
|
@ -0,0 +1,46 @@
|
|||
{
|
||||
"root": ["testdata/rotated/root_ca.crt", "testdata/secrets/root_ca.crt"],
|
||||
"crt": "testdata/rotated/intermediate_ca.crt",
|
||||
"key": "testdata/rotated/intermediate_ca_key",
|
||||
"password": "asdf",
|
||||
"address": "127.0.0.1:0",
|
||||
"dnsNames": ["127.0.0.1"],
|
||||
"logger": {"format": "text"},
|
||||
"tls": {
|
||||
"minVersion": 1.2,
|
||||
"maxVersion": 1.2,
|
||||
"renegotiation": false,
|
||||
"cipherSuites": [
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
|
||||
]
|
||||
},
|
||||
"authority": {
|
||||
"provisioners": [
|
||||
{
|
||||
"name": "mariano",
|
||||
"type": "jwk",
|
||||
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
|
||||
"key": {
|
||||
"use": "sig",
|
||||
"kty": "EC",
|
||||
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
|
||||
"crv": "P-256",
|
||||
"alg": "ES256",
|
||||
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
|
||||
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
|
||||
},
|
||||
"claims": {
|
||||
"minTLSCertDuration": "1s",
|
||||
"defaultTLSCertDuration": "5s"
|
||||
}
|
||||
}
|
||||
],
|
||||
"template": {
|
||||
"country": "US",
|
||||
"locality": "San Francisco",
|
||||
"organization": "Smallstep"
|
||||
}
|
||||
}
|
||||
}
|
46
ca/testdata/rotate-ca-3.json
vendored
Normal file
46
ca/testdata/rotate-ca-3.json
vendored
Normal file
|
@ -0,0 +1,46 @@
|
|||
{
|
||||
"root": "testdata/rotated/root_ca.crt",
|
||||
"crt": "testdata/rotated/intermediate_ca.crt",
|
||||
"key": "testdata/rotated/intermediate_ca_key",
|
||||
"password": "asdf",
|
||||
"address": "127.0.0.1:0",
|
||||
"dnsNames": ["127.0.0.1"],
|
||||
"logger": {"format": "text"},
|
||||
"tls": {
|
||||
"minVersion": 1.2,
|
||||
"maxVersion": 1.2,
|
||||
"renegotiation": false,
|
||||
"cipherSuites": [
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
|
||||
]
|
||||
},
|
||||
"authority": {
|
||||
"provisioners": [
|
||||
{
|
||||
"name": "mariano",
|
||||
"type": "jwk",
|
||||
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ",
|
||||
"key": {
|
||||
"use": "sig",
|
||||
"kty": "EC",
|
||||
"kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
|
||||
"crv": "P-256",
|
||||
"alg": "ES256",
|
||||
"x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y",
|
||||
"y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA"
|
||||
},
|
||||
"claims": {
|
||||
"minTLSCertDuration": "1s",
|
||||
"defaultTLSCertDuration": "5s"
|
||||
}
|
||||
}
|
||||
],
|
||||
"template": {
|
||||
"country": "US",
|
||||
"locality": "San Francisco",
|
||||
"organization": "Smallstep"
|
||||
}
|
||||
}
|
||||
}
|
12
ca/testdata/rotated/intermediate_ca.crt
vendored
Normal file
12
ca/testdata/rotated/intermediate_ca.crt
vendored
Normal file
|
@ -0,0 +1,12 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIBxTCCAWugAwIBAgIQLIY6MR/1fBRQY4ZTTsPAJjAKBggqhkjOPQQDAjAcMRow
|
||||
GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTAxMDcyMDExMzBaFw0yOTAx
|
||||
MDQyMDExMzBaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew
|
||||
WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARgtjL/KLNpdq81YYWaek1lrkPM/QF1
|
||||
m+ujwv5jya21fAXljdBLh6m2xco1GPfwPBbwUGlNOdEqE9Nq3Qx3ngPKo4GGMIGD
|
||||
MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw
|
||||
EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUqixeZ/K1HW9N6SVw7ONya98S
|
||||
u8UwHwYDVR0jBBgwFoAUgIzlCLxh/RlwEany4JQHOorLAIEwCgYIKoZIzj0EAwID
|
||||
SAAwRQIgdGX6lxThrKlt3v+3HJZlaWdmoeQ3vYwpJb9uHExZdVYCIQDCxsdI8EnB
|
||||
bxjnJscbT4zvqVsq6AmycdbFwgy8RIeVzg==
|
||||
-----END CERTIFICATE-----
|
8
ca/testdata/rotated/intermediate_ca_key
vendored
Normal file
8
ca/testdata/rotated/intermediate_ca_key
vendored
Normal file
|
@ -0,0 +1,8 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: AES-256-CBC,7dcc0a8c1d73c8d438184e0928875329
|
||||
|
||||
r6yrQrHg6zBZRSjQpe8RzyQALEfiT3/8lMvvPu3BX6yign5skMfCVMXZhzbmAwmR
|
||||
BJBIX+5hkudR2VN+hrsOyuU7FvIk4gx2c8buIlFObfYXIml0mpuThfm52ciAtOTE
|
||||
S0hkfYvPcOAjzaDZ+8Po/mYhkODgyvijogn4ioTF/Ss=
|
||||
-----END EC PRIVATE KEY-----
|
11
ca/testdata/rotated/root_ca.crt
vendored
Normal file
11
ca/testdata/rotated/root_ca.crt
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa
|
||||
MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw
|
||||
MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG
|
||||
SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL
|
||||
jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB
|
||||
/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR
|
||||
qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z
|
||||
YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc
|
||||
Sw==
|
||||
-----END CERTIFICATE-----
|
8
ca/testdata/rotated/root_ca_key
vendored
Normal file
8
ca/testdata/rotated/root_ca_key
vendored
Normal file
|
@ -0,0 +1,8 @@
|
|||
-----BEGIN EC PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: AES-256-CBC,8ce79d28601b9809905ef7c362a20749
|
||||
|
||||
H+pTTL3B5fLYycgHLxFOW0fZsayr7Y+BW8THKf12h8dk0/eOE1wNoX2TuMtpbZgO
|
||||
lMJdFPL+SAPCCmuZOZIcQDejRHVcYBq1wvrrnw/yfVawXC4xze+J4Y+q0J2WY+rM
|
||||
xcLGlEOIRZkvdDVGmSitEZBl0Ibk0p9tG++7QGqAvnk=
|
||||
-----END EC PRIVATE KEY-----
|
11
ca/testdata/secrets/federated_ca.crt
vendored
Normal file
11
ca/testdata/secrets/federated_ca.crt
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa
|
||||
MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw
|
||||
MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG
|
||||
SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL
|
||||
jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB
|
||||
/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR
|
||||
qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z
|
||||
YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc
|
||||
Sw==
|
||||
-----END CERTIFICATE-----
|
23
ca/tls.go
23
ca/tls.go
|
@ -41,7 +41,8 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
}
|
||||
|
||||
// Apply options if given
|
||||
if err := setTLSOptions(tlsConfig, options); err != nil {
|
||||
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
||||
if err := tlsCtx.apply(options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -50,7 +51,10 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
|
||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||
|
||||
// Update client transport
|
||||
c.client.Transport = tr
|
||||
|
||||
// Start renewer
|
||||
renewer.RunContext(ctx)
|
||||
|
@ -87,7 +91,8 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
}
|
||||
|
||||
// Apply options if given
|
||||
if err := setTLSOptions(tlsConfig, options); err != nil {
|
||||
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
||||
if err := tlsCtx.apply(options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -96,7 +101,10 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
|
||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||
|
||||
// Update client transport
|
||||
c.client.Transport = tr
|
||||
|
||||
// Start renewer
|
||||
renewer.RunContext(ctx)
|
||||
|
@ -238,8 +246,13 @@ func getPEM(i interface{}) ([]byte, error) {
|
|||
return pem.EncodeToMemory(block), nil
|
||||
}
|
||||
|
||||
func getRenewFunc(client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
|
||||
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
|
||||
return func() (*tls.Certificate, error) {
|
||||
// Get updated list of roots
|
||||
if err := ctx.applyRenew(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Get new certificate
|
||||
sign, err := client.Renew(tr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -6,13 +6,35 @@ import (
|
|||
)
|
||||
|
||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||
type TLSOption func(c *tls.Config) error
|
||||
type TLSOption func(ctx *TLSOptionCtx) error
|
||||
|
||||
// setTLSOptions takes one or more option function and applies them in order to
|
||||
// a tls.Config.
|
||||
func setTLSOptions(c *tls.Config, options []TLSOption) error {
|
||||
for _, opt := range options {
|
||||
if err := opt(c); err != nil {
|
||||
// TLSOptionCtx is the context modified on TLSOption methods.
|
||||
type TLSOptionCtx struct {
|
||||
Client *Client
|
||||
Config *tls.Config
|
||||
OnRenewFunc []TLSOption
|
||||
}
|
||||
|
||||
// newTLSOptionCtx creates the TLSOption context.
|
||||
func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx {
|
||||
return &TLSOptionCtx{
|
||||
Client: c,
|
||||
Config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
||||
for _, fn := range options {
|
||||
if err := fn(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ctx *TLSOptionCtx) applyRenew() error {
|
||||
for _, fn := range ctx.OnRenewFunc {
|
||||
if err := fn(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -22,8 +44,8 @@ func setTLSOptions(c *tls.Config, options []TLSOption) error {
|
|||
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
||||
// a valid TLS client certificate. This is the default option for mTLS servers.
|
||||
func RequireAndVerifyClientCert() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
c.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
ctx.Config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -31,8 +53,8 @@ func RequireAndVerifyClientCert() TLSOption {
|
|||
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
||||
// a TLS client certificate if it is provided. It does not requires a certificate.
|
||||
func VerifyClientCertIfGiven() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
c.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
ctx.Config.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -41,11 +63,11 @@ func VerifyClientCertIfGiven() TLSOption {
|
|||
// defines the set of root certificate authorities that clients use when
|
||||
// verifying server certificates.
|
||||
func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
if c.RootCAs == nil {
|
||||
c.RootCAs = x509.NewCertPool()
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
c.RootCAs.AddCert(cert)
|
||||
ctx.Config.RootCAs.AddCert(cert)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -54,11 +76,163 @@ func AddRootCA(cert *x509.Certificate) TLSOption {
|
|||
// defines the set of root certificate authorities that servers use if required
|
||||
// to verify a client certificate by the policy in ClientAuth.
|
||||
func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
if c.ClientCAs == nil {
|
||||
c.ClientCAs = x509.NewCertPool()
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
c.ClientCAs.AddCert(cert)
|
||||
ctx.Config.ClientCAs.AddCert(cert)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddRootsToRootCAs does a roots request and adds to the tls.Config RootCAs all
|
||||
// the certificates in the response. RootCAs defines the set of root certificate
|
||||
// authorities that clients use when verifying server certificates.
|
||||
//
|
||||
// BootstrapServer and BootstrapClient methods include this option by default.
|
||||
func AddRootsToRootCAs() TLSOption {
|
||||
fn := func(ctx *TLSOptionCtx) error {
|
||||
certs, err := ctx.Client.Roots()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||
return fn(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs
|
||||
// all the certificates in the response. ClientCAs defines the set of root
|
||||
// certificate authorities that servers use if required to verify a client
|
||||
// certificate by the policy in ClientAuth.
|
||||
//
|
||||
// BootstrapServer method includes this option by default.
|
||||
func AddRootsToClientCAs() TLSOption {
|
||||
fn := func(ctx *TLSOptionCtx) error {
|
||||
certs, err := ctx.Client.Roots()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||
return fn(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// AddFederationToRootCAs does a federation request and adds to the tls.Config
|
||||
// RootCAs all the certificates in the response. RootCAs defines the set of root
|
||||
// certificate authorities that clients use when verifying server certificates.
|
||||
func AddFederationToRootCAs() TLSOption {
|
||||
fn := func(ctx *TLSOptionCtx) error {
|
||||
certs, err := ctx.Client.Federation()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||
return fn(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// AddFederationToClientCAs does a federation request and adds to the tls.Config
|
||||
// ClientCAs all the certificates in the response. ClientCAs defines the set of
|
||||
// root certificate authorities that servers use if required to verify a client
|
||||
// certificate by the policy in ClientAuth.
|
||||
func AddFederationToClientCAs() TLSOption {
|
||||
fn := func(ctx *TLSOptionCtx) error {
|
||||
certs, err := ctx.Client.Federation()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||
return fn(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// AddRootsToCAs does a roots request and adds the resulting certs to the
|
||||
// tls.Config RootCAs and ClientCAs. Combines the functionality of
|
||||
// AddRootsToRootCAs and AddRootsToClientCAs.
|
||||
func AddRootsToCAs() TLSOption {
|
||||
fn := func(ctx *TLSOptionCtx) error {
|
||||
certs, err := ctx.Client.Roots()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||
return fn(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// AddFederationToCAs does a federation request and adds the resulting certs to the
|
||||
// tls.Config RootCAs and ClientCAs. Combines the functionality of
|
||||
// AddFederationToRootCAs and AddFederationToClientCAs.
|
||||
func AddFederationToCAs() TLSOption {
|
||||
fn := func(ctx *TLSOptionCtx) error {
|
||||
certs, err := ctx.Client.Federation()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||
return fn(ctx)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,33 +4,69 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_setTLSOptions(t *testing.T) {
|
||||
func Test_newTLSOptionCtx(t *testing.T) {
|
||||
client, err := NewClient("https://ca.smallstep.com", WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
|
||||
type args struct {
|
||||
c *Client
|
||||
config *tls.Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *TLSOptionCtx
|
||||
}{
|
||||
{"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSOptionCtx_apply(t *testing.T) {
|
||||
fail := func() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
return fmt.Errorf("an error")
|
||||
}
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Config *tls.Config
|
||||
}
|
||||
type args struct {
|
||||
c *tls.Config
|
||||
options []TLSOption
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false},
|
||||
{"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false},
|
||||
{"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
|
||||
{"ok", fields{&tls.Config{}}, args{[]TLSOption{RequireAndVerifyClientCert()}}, false},
|
||||
{"ok", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven()}}, false},
|
||||
{"fail", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
||||
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||
ctx := &TLSOptionCtx{
|
||||
Config: tt.fields.Config,
|
||||
}
|
||||
if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr {
|
||||
t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -45,13 +81,15 @@ func TestRequireAndVerifyClientCert(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := RequireAndVerifyClientCert()(got); err != nil {
|
||||
ctx := &TLSOptionCtx{
|
||||
Config: &tls.Config{},
|
||||
}
|
||||
if err := RequireAndVerifyClientCert()(ctx); err != nil {
|
||||
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want)
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -66,13 +104,15 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := VerifyClientCertIfGiven()(got); err != nil {
|
||||
ctx := &TLSOptionCtx{
|
||||
Config: &tls.Config{},
|
||||
}
|
||||
if err := VerifyClientCertIfGiven()(ctx); err != nil {
|
||||
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want)
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -95,13 +135,15 @@ func TestAddRootCA(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := AddRootCA(tt.args.cert)(got); err != nil {
|
||||
ctx := &TLSOptionCtx{
|
||||
Config: &tls.Config{},
|
||||
}
|
||||
if err := AddRootCA(tt.args.cert)(ctx); err != nil {
|
||||
t.Errorf("AddRootCA() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("AddRootCA() = %v, want %v", got, tt.want)
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
t.Errorf("AddRootCA() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -124,14 +166,380 @@ func TestAddClientCA(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := AddClientCA(tt.args.cert)(got); err != nil {
|
||||
ctx := &TLSOptionCtx{
|
||||
Config: &tls.Config{},
|
||||
}
|
||||
if err := AddClientCA(tt.args.cert)(ctx); err != nil {
|
||||
t.Errorf("AddClientCA() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("AddClientCA() = %v, want %v", got, tt.want)
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
t.Errorf("AddClientCA() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddRootsToRootCAs(t *testing.T) {
|
||||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cert := parseCertificate(string(root))
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
type args struct {
|
||||
client *Client
|
||||
config *tls.Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tls.Config
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false},
|
||||
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &TLSOptionCtx{
|
||||
Client: tt.args.client,
|
||||
Config: tt.args.config,
|
||||
}
|
||||
if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddRootsToClientCAs(t *testing.T) {
|
||||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cert := parseCertificate(string(root))
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
type args struct {
|
||||
client *Client
|
||||
config *tls.Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tls.Config
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false},
|
||||
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &TLSOptionCtx{
|
||||
Client: tt.args.client,
|
||||
Config: tt.args.config,
|
||||
}
|
||||
if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFederationToRootCAs(t *testing.T) {
|
||||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
crt1 := parseCertificate(string(root))
|
||||
crt2 := parseCertificate(string(federated))
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(crt1)
|
||||
pool.AddCert(crt2)
|
||||
|
||||
type args struct {
|
||||
client *Client
|
||||
config *tls.Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tls.Config
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false},
|
||||
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &TLSOptionCtx{
|
||||
Client: tt.args.client,
|
||||
Config: tt.args.config,
|
||||
}
|
||||
if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
// Federated roots are randomly sorted
|
||||
if !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) || ctx.Config.ClientCAs != nil {
|
||||
t.Errorf("AddFederationToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFederationToClientCAs(t *testing.T) {
|
||||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
crt1 := parseCertificate(string(root))
|
||||
crt2 := parseCertificate(string(federated))
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(crt1)
|
||||
pool.AddCert(crt2)
|
||||
|
||||
type args struct {
|
||||
client *Client
|
||||
config *tls.Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tls.Config
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false},
|
||||
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &TLSOptionCtx{
|
||||
Client: tt.args.client,
|
||||
Config: tt.args.config,
|
||||
}
|
||||
if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
// Federated roots are randomly sorted
|
||||
if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || ctx.Config.RootCAs != nil {
|
||||
t.Errorf("AddFederationToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddRootsToCAs(t *testing.T) {
|
||||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cert := parseCertificate(string(root))
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
type args struct {
|
||||
client *Client
|
||||
config *tls.Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tls.Config
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
|
||||
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &TLSOptionCtx{
|
||||
Client: tt.args.client,
|
||||
Config: tt.args.config,
|
||||
}
|
||||
if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFederationToCAs(t *testing.T) {
|
||||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
crt1 := parseCertificate(string(root))
|
||||
crt2 := parseCertificate(string(federated))
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(crt1)
|
||||
pool.AddCert(crt2)
|
||||
|
||||
type args struct {
|
||||
client *Client
|
||||
config *tls.Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tls.Config
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
|
||||
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &TLSOptionCtx{
|
||||
Client: tt.args.client,
|
||||
Config: tt.args.config,
|
||||
}
|
||||
if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr {
|
||||
t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
// Federated roots are randomly sorted
|
||||
if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) {
|
||||
t.Errorf("AddFederationToCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func equalPools(a, b *x509.CertPool) bool {
|
||||
subjects := a.Subjects()
|
||||
sA := make([]string, len(subjects))
|
||||
for i := range subjects {
|
||||
sA[i] = string(subjects[i])
|
||||
}
|
||||
subjects = b.Subjects()
|
||||
sB := make([]string, len(subjects))
|
||||
for i := range subjects {
|
||||
sB[i] = string(subjects[i])
|
||||
}
|
||||
sort.Sort(sort.StringSlice(sA))
|
||||
sort.Sort(sort.StringSlice(sB))
|
||||
return reflect.DeepEqual(sA, sB)
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ import (
|
|||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
stepJOSE "github.com/smallstep/cli/jose"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
|
@ -242,16 +242,15 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
reset := setMinCertDuration(1 * time.Second)
|
||||
defer reset()
|
||||
|
||||
// Start CA
|
||||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
clientDomain := "test.domain"
|
||||
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
|
||||
client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second)
|
||||
|
||||
// Start mTLS server
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -274,13 +273,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
|||
defer srvTLS.Close()
|
||||
|
||||
// Transport
|
||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second)
|
||||
tr1, err := client.Transport(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.Transport() error = %v", err)
|
||||
}
|
||||
// Transport with tlsConfig
|
||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second)
|
||||
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
|
||||
|
@ -367,9 +366,9 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
|||
t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
|
||||
}
|
||||
|
||||
// Wait for renewal 40s == 1m-1m/3
|
||||
log.Printf("Sleeping for %s ...\n", 40*time.Second)
|
||||
time.Sleep(40 * time.Second)
|
||||
// Wait for renewal
|
||||
log.Printf("Sleeping for %s ...\n", 5*time.Second)
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("renewed "+tt.name, func(t *testing.T) {
|
||||
|
|
Loading…
Reference in a new issue