567 lines
15 KiB
Go
567 lines
15 KiB
Go
package ca
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/hex"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/smallstep/certificates/api"
|
|
"github.com/smallstep/certificates/authority"
|
|
"github.com/smallstep/cli/crypto/randutil"
|
|
stepJOSE "github.com/smallstep/cli/jose"
|
|
jose "gopkg.in/square/go-jose.v2"
|
|
"gopkg.in/square/go-jose.v2/jwt"
|
|
)
|
|
|
|
func generateOTT(subject string) string {
|
|
now := time.Now()
|
|
jwk, err := stepJOSE.ParseKey("testdata/secrets/ott_mariano_priv.jwk", stepJOSE.WithPassword([]byte("password")))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID)
|
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
id, err := randutil.ASCII(64)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
cl := struct {
|
|
jwt.Claims
|
|
SANS []string `json:"sans"`
|
|
}{
|
|
Claims: jwt.Claims{
|
|
ID: id,
|
|
Subject: subject,
|
|
Issuer: "mariano",
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
Audience: []string{"https://127.0.0.1:0/sign"},
|
|
},
|
|
SANS: []string{subject},
|
|
}
|
|
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return raw
|
|
}
|
|
|
|
// startTestServer starts an httptest TLS server that serves handler using
// tlsConfig and returns it. The caller is responsible for closing the server.
func startTestServer(tlsConfig *tls.Config, handler http.Handler) *httptest.Server {
	server := httptest.NewUnstartedServer(handler)
	server.TLS = tlsConfig
	server.StartTLS()

	// Force the use of GetCertificate on IPs: drop the static certificates
	// that are present on the server configuration once it has been started.
	server.TLS.Certificates = nil

	return server
}
|
|
|
|
func startCATestServer() *httptest.Server {
|
|
config, err := authority.LoadConfiguration("testdata/ca.json")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
ca, err := New(config)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
// Use a httptest.Server instead
|
|
return startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
|
|
}
|
|
|
|
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
|
|
srv := startCATestServer()
|
|
defer srv.Close()
|
|
return signDuration(srv, domain, 0)
|
|
}
|
|
|
|
func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) {
|
|
req, pk, err := CreateSignRequest(generateOTT(domain))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if duration > 0 {
|
|
req.NotBefore = api.NewTimeDuration(time.Now())
|
|
req.NotAfter = api.NewTimeDuration(req.NotBefore.Time().Add(duration))
|
|
}
|
|
|
|
client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
sr, err := client.Sign(req)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return client, sr, pk
|
|
}
|
|
|
|
// serverHandler returns the handler used by the test servers.
//
// Requests to "/no-cert" are answered with "ok" without inspecting the TLS
// state. Every other request must carry a peer certificate whose common name
// is clientDomain and whose DNS names are exactly [clientDomain]; otherwise the
// handler writes "fail" and reports the problem on t. On success the SHA-256
// fingerprint of the peer certificate is returned hex-encoded in the
// "x-fingerprint" header and the body is "ok".
func serverHandler(t *testing.T, clientDomain string) http.Handler {
	handle := func(w http.ResponseWriter, req *http.Request) {
		if req.RequestURI == "/no-cert" {
			w.Write([]byte("ok"))
			return
		}

		if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
			w.Write([]byte("fail"))
			t.Error("http.Request.TLS does not have peer certificates")
			return
		}

		leaf := req.TLS.PeerCertificates[0]
		wantNames := []string{clientDomain}
		switch {
		case leaf.Subject.CommonName != clientDomain:
			w.Write([]byte("fail"))
			t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", leaf.Subject.CommonName, clientDomain)
			return
		case !reflect.DeepEqual(leaf.DNSNames, wantNames):
			w.Write([]byte("fail"))
			t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", leaf.DNSNames, wantNames)
			return
		}

		// Expose the certificate fingerprint so callers can check rotation.
		fingerprint := sha256.Sum256(leaf.Raw)
		w.Header().Set("x-fingerprint", hex.EncodeToString(fingerprint[:]))
		w.Write([]byte("ok"))
	}
	return http.HandlerFunc(handle)
}
|
|
|
|
func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|
clientDomain := "test.domain"
|
|
client, sr, pk := sign("127.0.0.1")
|
|
|
|
// Create mTLS server
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
|
|
if err != nil {
|
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
}
|
|
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
|
defer srvMTLS.Close()
|
|
|
|
// Create TLS server
|
|
ctx, cancel = context.WithCancel(context.Background())
|
|
defer cancel()
|
|
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
|
|
if err != nil {
|
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
}
|
|
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
|
defer srvTLS.Close()
|
|
|
|
tests := []struct {
|
|
name string
|
|
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
|
wantErr map[string]bool
|
|
}{
|
|
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
tr, err := client.Transport(context.Background(), sr, pk)
|
|
if err != nil {
|
|
t.Errorf("Client.Transport() error = %v", err)
|
|
return nil
|
|
}
|
|
return &http.Client{
|
|
Transport: tr,
|
|
}
|
|
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
|
|
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
|
if err != nil {
|
|
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
|
return nil
|
|
}
|
|
tr, err := getDefaultTransport(tlsConfig)
|
|
if err != nil {
|
|
t.Errorf("getDefaultTransport() error = %v", err)
|
|
return nil
|
|
}
|
|
return &http.Client{
|
|
Transport: tr,
|
|
}
|
|
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
|
|
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
root, err := RootCertificate(sr)
|
|
if err != nil {
|
|
t.Errorf("RootCertificate() error = %v", err)
|
|
return nil
|
|
}
|
|
tlsConfig := getDefaultTLSConfig(sr)
|
|
tlsConfig.RootCAs = x509.NewCertPool()
|
|
tlsConfig.RootCAs.AddCert(root)
|
|
|
|
tr, err := getDefaultTransport(tlsConfig)
|
|
if err != nil {
|
|
t.Errorf("getDefaultTransport() error = %v", err)
|
|
return nil
|
|
}
|
|
return &http.Client{
|
|
Transport: tr,
|
|
}
|
|
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
|
|
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
return &http.Client{}
|
|
}, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
client, sr, pk := sign(clientDomain)
|
|
cli := tt.getClient(t, client, sr, pk)
|
|
if cli == nil {
|
|
return
|
|
}
|
|
for path, wantErr := range tt.wantErr {
|
|
t.Run(path, func(t *testing.T) {
|
|
resp, err := cli.Get(path)
|
|
if (err != nil) != wantErr {
|
|
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
|
|
return
|
|
}
|
|
if wantErr {
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
b, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
|
}
|
|
if !bytes.Equal(b, []byte("ok")) {
|
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
|
reset := setMinCertDuration(1 * time.Second)
|
|
defer reset()
|
|
|
|
// Start CA
|
|
ca := startCATestServer()
|
|
defer ca.Close()
|
|
|
|
clientDomain := "test.domain"
|
|
client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second)
|
|
|
|
// Start mTLS server
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
|
|
if err != nil {
|
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
}
|
|
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
|
defer srvMTLS.Close()
|
|
|
|
// Start TLS server
|
|
ctx, cancel = context.WithCancel(context.Background())
|
|
defer cancel()
|
|
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
|
|
if err != nil {
|
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
}
|
|
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
|
defer srvTLS.Close()
|
|
|
|
// Transport
|
|
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second)
|
|
tr1, err := client.Transport(context.Background(), sr, pk)
|
|
if err != nil {
|
|
t.Fatalf("Client.Transport() error = %v", err)
|
|
}
|
|
// Transport with tlsConfig
|
|
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second)
|
|
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
|
if err != nil {
|
|
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
|
|
}
|
|
tr2, err := getDefaultTransport(tlsConfig)
|
|
if err != nil {
|
|
t.Fatalf("getDefaultTransport() error = %v", err)
|
|
}
|
|
// No client cert
|
|
root, err := RootCertificate(sr)
|
|
if err != nil {
|
|
t.Fatalf("RootCertificate() error = %v", err)
|
|
}
|
|
tlsConfig = getDefaultTLSConfig(sr)
|
|
tlsConfig.RootCAs = x509.NewCertPool()
|
|
tlsConfig.RootCAs.AddCert(root)
|
|
tr3, err := getDefaultTransport(tlsConfig)
|
|
if err != nil {
|
|
t.Fatalf("getDefaultTransport() error = %v", err)
|
|
}
|
|
|
|
// Disable keep alives to force TLS handshake
|
|
tr1.DisableKeepAlives = true
|
|
tr2.DisableKeepAlives = true
|
|
tr3.DisableKeepAlives = true
|
|
|
|
tests := []struct {
|
|
name string
|
|
client *http.Client
|
|
wantErr map[string]bool
|
|
}{
|
|
{"with transport", &http.Client{Transport: tr1}, map[string]bool{
|
|
srvTLS.URL: false,
|
|
srvMTLS.URL: false,
|
|
}},
|
|
{"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{
|
|
srvTLS.URL: false,
|
|
srvMTLS.URL: false,
|
|
}},
|
|
{"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{
|
|
srvTLS.URL + "/no-cert": false,
|
|
srvMTLS.URL + "/no-cert": true,
|
|
}},
|
|
{"fail with default", &http.Client{}, map[string]bool{
|
|
srvTLS.URL + "/no-cert": true,
|
|
srvMTLS.URL + "/no-cert": true,
|
|
}},
|
|
}
|
|
|
|
// To count different cert fingerprints
|
|
fingerprints := map[string]struct{}{}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
for path, wantErr := range tt.wantErr {
|
|
t.Run(path, func(t *testing.T) {
|
|
resp, err := tt.client.Get(path)
|
|
if (err != nil) != wantErr {
|
|
t.Errorf("http.Client.Get() error = %v", err)
|
|
return
|
|
}
|
|
if wantErr {
|
|
return
|
|
}
|
|
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
|
|
fingerprints[fp] = struct{}{}
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
b, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Errorf("ioutil.RealAdd() error = %v", err)
|
|
return
|
|
}
|
|
if !bytes.Equal(b, []byte("ok")) {
|
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
if l := len(fingerprints); l != 2 {
|
|
t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
|
|
}
|
|
|
|
// Wait for renewal
|
|
log.Printf("Sleeping for %s ...\n", 5*time.Second)
|
|
time.Sleep(5 * time.Second)
|
|
|
|
for _, tt := range tests {
|
|
t.Run("renewed "+tt.name, func(t *testing.T) {
|
|
for path, wantErr := range tt.wantErr {
|
|
t.Run(path, func(t *testing.T) {
|
|
resp, err := tt.client.Get(path)
|
|
if (err != nil) != wantErr {
|
|
t.Errorf("http.Client.Get() error = %v", err)
|
|
return
|
|
}
|
|
if wantErr {
|
|
return
|
|
}
|
|
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
|
|
fingerprints[fp] = struct{}{}
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
b, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Errorf("ioutil.RealAdd() error = %v", err)
|
|
return
|
|
}
|
|
if !bytes.Equal(b, []byte("ok")) {
|
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
if l := len(fingerprints); l != 4 {
|
|
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
|
}
|
|
}
|
|
|
|
func TestClient_GetCertificateRenewer(t *testing.T) {
|
|
reset := setMinCertDuration(1 * time.Second)
|
|
defer reset()
|
|
|
|
// Start CA
|
|
ca := startCATestServer()
|
|
defer ca.Close()
|
|
|
|
client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second)
|
|
|
|
badOption := func(ctx *TLSOptionCtx) error {
|
|
return errors.New("foo")
|
|
}
|
|
|
|
type args struct {
|
|
sign *api.SignResponse
|
|
pk crypto.PrivateKey
|
|
options []TLSOption
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{"ok", args{sr, pk, nil}, false},
|
|
{"bad-pk", args{sr, []byte("foo"), nil}, true},
|
|
{"bad-option", args{sr, pk, []TLSOption{badOption}}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := client.GetCertificateRenewer(tt.args.sign, tt.args.pk, tt.args.options...)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Client.GetCertificateRenewer() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if tt.wantErr == false {
|
|
fn := got.RenewCertificate
|
|
var i int
|
|
got.RenewCertificate = func() (*tls.Certificate, error) {
|
|
cert, err := fn()
|
|
if err != nil {
|
|
t.Errorf("TLSRenewer.RenewCertificate() error = %v", err)
|
|
} else {
|
|
i++
|
|
}
|
|
return cert, err
|
|
}
|
|
got.Run()
|
|
time.Sleep(5 * time.Second)
|
|
if i == 0 {
|
|
t.Errorf("Client.GetCertificateRenewer() certificate was not renewed")
|
|
}
|
|
got.Stop()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCertificate(t *testing.T) {
|
|
cert := parseCertificate(certPEM)
|
|
ok := &api.SignResponse{
|
|
ServerPEM: api.Certificate{Certificate: cert},
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
sign *api.SignResponse
|
|
want *x509.Certificate
|
|
wantErr bool
|
|
}{
|
|
{"ok", ok, cert, false},
|
|
{"fail", &api.SignResponse{}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := Certificate(tt.sign)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Certificate() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Certificate() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIntermediateCertificate(t *testing.T) {
|
|
intermediate := parseCertificate(rootPEM)
|
|
ok := &api.SignResponse{
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
|
CaPEM: api.Certificate{Certificate: intermediate},
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
sign *api.SignResponse
|
|
want *x509.Certificate
|
|
wantErr bool
|
|
}{
|
|
{"ok", ok, intermediate, false},
|
|
{"fail", &api.SignResponse{}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := IntermediateCertificate(tt.sign)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("IntermediateCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("IntermediateCertificate() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRootCertificateCertificate(t *testing.T) {
|
|
root := parseCertificate(rootPEM)
|
|
ok := &api.SignResponse{
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
|
TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{
|
|
{root, root},
|
|
}},
|
|
}
|
|
noTLS := &api.SignResponse{
|
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
sign *api.SignResponse
|
|
want *x509.Certificate
|
|
wantErr bool
|
|
}{
|
|
{"ok", ok, root, false},
|
|
{"fail", &api.SignResponse{}, nil, true},
|
|
{"no tls", noTLS, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := RootCertificate(tt.sign)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("RootCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("RootCertificate() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|