Use mTLS by default on SDK methods.

Add options to modify the tls.Config for different configurations.
Fixes #7
This commit is contained in:
Mariano Cano 2018-11-21 13:29:18 -08:00
parent bb03aadddf
commit d872f09910
9 changed files with 417 additions and 419 deletions

View file

@ -104,15 +104,8 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) (
return client, sr, pk
}
func TestClient_GetServerTLSConfig_http(t *testing.T) {
client, sr, pk := sign("127.0.0.1")
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
func serverHandler(t *testing.T, clientDomain string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.RequestURI != "/no-cert" {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
@ -129,18 +122,46 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
// Add serial number to check rotation
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
w.Header().Set("x-fingerprint", hex.EncodeToString(sum[:]))
}
w.Write([]byte("ok"))
}))
defer srv.Close()
})
}
func TestClient_GetServerTLSConfig_http(t *testing.T) {
clientDomain := "test.domain"
client, sr, pk := sign("127.0.0.1")
// Create mTLS server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
defer srvMTLS.Close()
// Create TLS server
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
defer srvTLS.Close()
tests := []struct {
name string
path string
wantErr bool
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
wantErr map[string]bool
}{
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tr, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.Transport() error = %v", err)
@ -149,8 +170,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
return &http.Client{
Transport: tr,
}
}},
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
@ -164,8 +185,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
return &http.Client{
Transport: tr,
}
}},
{"ok with no cert", "/no-cert", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
root, err := RootCertificate(sr)
if err != nil {
t.Errorf("RootCertificate() error = %v", err)
@ -183,33 +204,38 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
return &http.Client{
Transport: tr,
}
}},
{"fail with default", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
return &http.Client{}
}},
}, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, sr, pk := sign(clientDomain)
cli := tt.getClient(t, client, sr, pk)
if cli != nil {
resp, err := cli.Get(srv.URL + tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
if cli == nil {
return
}
for path, wantErr := range tt.wantErr {
t.Run(path, func(t *testing.T) {
resp, err := cli.Get(path)
if (err != nil) != wantErr {
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
return
}
if wantErr {
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
})
}
@ -224,44 +250,36 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
ca := startCATestServer()
defer ca.Close()
clientDomain := "test.domain"
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
// Start mTLS server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
fingerprints := make(map[string]struct{})
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
defer srvMTLS.Close()
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
// Add serial number to check rotation
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
w.Write([]byte("ok"))
}))
defer srv.Close()
// Start TLS server
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
defer srvTLS.Close()
// Clients: transport and tlsConfig
// Transport
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tr1, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.Transport() error = %v", err)
}
// Transport with tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
@ -271,227 +289,15 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
// Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true
tr2.DisableKeepAlives = true
tests := []struct {
name string
client *http.Client
}{
{"with transport", &http.Client{Transport: tr1}},
{"with tlsConfig", &http.Client{Transport: tr2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
if l := len(fingerprints); l != 2 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
}
// Wait for renewal 40s == 1m-1m/3
log.Printf("Sleeping for %s ...\n", 40*time.Second)
time.Sleep(40 * time.Second)
for _, tt := range tests {
t.Run("renewed "+tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
if l := len(fingerprints); l != 4 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
}
}
func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
client, sr, pk := sign("127.0.0.1")
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
// No client cert
root, err := RootCertificate(sr)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
t.Fatalf("RootCertificate() error = %v", err)
}
clientDomain := "test.domain"
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.RequestURI != "/no-cert" {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
}
w.Write([]byte("ok"))
}))
defer srv.Close()
tests := []struct {
name string
path string
wantErr bool
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
}{
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tr, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.Transport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
return nil
}
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"fail with no cert", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
root, err := RootCertificate(sr)
if err != nil {
t.Errorf("RootCertificate() error = %v", err)
return nil
}
tlsConfig := getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, sr, pk := sign(clientDomain)
cli := tt.getClient(t, client, sr, pk)
if cli != nil {
resp, err := cli.Get(srv.URL + tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
}
})
}
}
func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
// Start CA
ca := startCATestServer()
defer ca.Close()
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
fingerprints := make(map[string]struct{})
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
// Add serial number to check rotation
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
w.Write([]byte("ok"))
}))
defer srv.Close()
// Clients: transport and tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tr1, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.Transport() error = %v", err)
}
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
}
tr2, err := getDefaultTransport(tlsConfig)
tlsConfig = getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr3, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
@ -499,34 +305,66 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
// Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true
tr2.DisableKeepAlives = true
tr3.DisableKeepAlives = true
tests := []struct {
name string
client *http.Client
name string
client *http.Client
wantErr map[string]bool
}{
{"with transport", &http.Client{Transport: tr1}},
{"with tlsConfig", &http.Client{Transport: tr2}},
{"with transport", &http.Client{Transport: tr1}, map[string]bool{
srvTLS.URL: false,
srvMTLS.URL: false,
}},
{"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{
srvTLS.URL: false,
srvMTLS.URL: false,
}},
{"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{
srvTLS.URL + "/no-cert": false,
srvMTLS.URL + "/no-cert": true,
}},
{"fail with default", &http.Client{}, map[string]bool{
srvTLS.URL + "/no-cert": true,
srvMTLS.URL + "/no-cert": true,
}},
}
// To count different cert fingerprints
fingerprints := map[string]struct{}{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
for path, wantErr := range tt.wantErr {
t.Run(path, func(t *testing.T) {
resp, err := tt.client.Get(path)
if (err != nil) != wantErr {
t.Errorf("http.Client.Get() error = %v", err)
return
}
if wantErr {
return
}
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
fingerprints[fp] = struct{}{}
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err)
return
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
return
}
})
}
})
}
if l := len(fingerprints); l != 2 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
}
// Wait for renewal 40s == 1m-1m/3
@ -535,17 +373,31 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
for _, tt := range tests {
t.Run("renewed "+tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
for path, wantErr := range tt.wantErr {
t.Run(path, func(t *testing.T) {
resp, err := tt.client.Get(path)
if (err != nil) != wantErr {
t.Errorf("http.Client.Get() error = %v", err)
return
}
if wantErr {
return
}
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
fingerprints[fp] = struct{}{}
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err)
return
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
return
}
})
}
})
}