From 50b9aaec57db7d42d4c5997024fa8090c6aba27c Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 21 Apr 2021 18:07:59 -0700 Subject: [PATCH] Add new identity tests. --- ca/identity/identity_test.go | 115 ++++++++++++++++++++++-- ca/identity/testdata/config/tunnel.json | 7 ++ 2 files changed, 115 insertions(+), 7 deletions(-) create mode 100644 ca/identity/testdata/config/tunnel.json diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index 7064cead..ce64768c 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -63,6 +63,7 @@ func TestIdentity_Kind(t *testing.T) { }{ {"disabled", fields{""}, Disabled}, {"mutualTLS", fields{"mTLS"}, MutualTLS}, + {"tunnelTLS", fields{"tTLS"}, TunnelTLS}, {"unknown", fields{"unknown"}, Type("unknown")}, } for _, tt := range tests { @@ -82,19 +83,27 @@ func TestIdentity_Validate(t *testing.T) { Type string Certificate string Key string + Host string + Root string } tests := []struct { name string fields fields wantErr bool }{ - {"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, false}, + {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false}, + {"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false}, {"ok disabled", fields{}, false}, - {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, true}, - {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key"}, true}, - {"fail key", fields{"mTLS", "testdata/identity/identity.crt", ""}, true}, - {"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key"}, true}, - {"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key"}, true}, + {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true}, + {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true}, + {"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true}, + {"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true}, + {"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true}, + {"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true}, + {"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -102,6 +111,8 @@ func TestIdentity_Validate(t *testing.T) { Type: tt.fields.Type, Certificate: tt.fields.Certificate, Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, } if err := i.Validate(); (err != nil) != tt.wantErr { t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr) @@ -127,7 +138,8 @@ func TestIdentity_TLSCertificate(t *testing.T) { want tls.Certificate wantErr bool }{ - {"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, + {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, + {"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, {"ok disabled", fields{}, tls.Certificate{}, false}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, @@ -255,6 +267,95 @@ func TestWriteDefaultIdentity(t *testing.T) { } } +func TestIdentity_GetClientCertificateFunc(t *testing.T) { + expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key") + if err != nil { + t.Fatal(err) + } + + type fields struct { + Type string + Certificate string + Key string + Host string + Root string + } + tests := []struct { + name string + fields fields + want *tls.Certificate + wantErr bool + }{ + {"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false}, + {"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false}, + {"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true}, + {"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &Identity{ + Type: tt.fields.Type, + Certificate: tt.fields.Certificate, + Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, + } + fn := i.GetClientCertificateFunc() + got, err := fn(&tls.CertificateRequestInfo{}) + if (err != nil) != tt.wantErr { + t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIdentity_GetCertPool(t *testing.T) { + type fields struct { + Type string + Certificate string + Key string + Host string + Root string + } + tests := []struct { + name string + fields fields + wantSubjects [][]byte + wantErr bool + }{ + {"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false}, + {"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false}, + {"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true}, + {"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &Identity{ + Type: tt.fields.Type, + Certificate: tt.fields.Certificate, + Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, + } + got, err := i.GetCertPool() + if (err != nil) != tt.wantErr { + t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil { + subjects := got.Subjects() + if !reflect.DeepEqual(subjects, tt.wantSubjects) { + t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) + } + } + + }) + } +} + type renewer struct { pool *x509.CertPool sign *api.SignResponse diff --git a/ca/identity/testdata/config/tunnel.json b/ca/identity/testdata/config/tunnel.json new file mode 100644 index 00000000..49c76a55 --- /dev/null +++ b/ca/identity/testdata/config/tunnel.json @@ -0,0 +1,7 @@ +{ + "type": "mTLS", + "crt": "testdata/identity/identity.crt", + "key": "testdata/identity/identity_key", + "host": "tunnel:443", + "root": "testdata/certs/root_ca.crt" +} \ No newline at end of file