Add new identity tests.

This commit is contained in:
Mariano Cano 2021-04-21 18:07:59 -07:00
parent e414d0c8ea
commit 50b9aaec57
2 changed files with 115 additions and 7 deletions

View file

@ -63,6 +63,7 @@ func TestIdentity_Kind(t *testing.T) {
}{ }{
{"disabled", fields{""}, Disabled}, {"disabled", fields{""}, Disabled},
{"mutualTLS", fields{"mTLS"}, MutualTLS}, {"mutualTLS", fields{"mTLS"}, MutualTLS},
{"tunnelTLS", fields{"tTLS"}, TunnelTLS},
{"unknown", fields{"unknown"}, Type("unknown")}, {"unknown", fields{"unknown"}, Type("unknown")},
} }
for _, tt := range tests { for _, tt := range tests {
@ -82,19 +83,27 @@ func TestIdentity_Validate(t *testing.T) {
Type string Type string
Certificate string Certificate string
Key string Key string
Host string
Root string
} }
tests := []struct { tests := []struct {
name string name string
fields fields fields fields
wantErr bool wantErr bool
}{ }{
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, false}, {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false},
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false},
{"ok disabled", fields{}, false}, {"ok disabled", fields{}, false},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, true}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true},
{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key"}, true}, {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true},
{"fail key", fields{"mTLS", "testdata/identity/identity.crt", ""}, true}, {"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true},
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key"}, true}, {"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key"}, true}, {"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true},
{"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true},
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true},
{"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -102,6 +111,8 @@ func TestIdentity_Validate(t *testing.T) {
Type: tt.fields.Type, Type: tt.fields.Type,
Certificate: tt.fields.Certificate, Certificate: tt.fields.Certificate,
Key: tt.fields.Key, Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
} }
if err := i.Validate(); (err != nil) != tt.wantErr { if err := i.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr)
@ -127,7 +138,8 @@ func TestIdentity_TLSCertificate(t *testing.T) {
want tls.Certificate want tls.Certificate
wantErr bool wantErr bool
}{ }{
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok disabled", fields{}, tls.Certificate{}, false}, {"ok disabled", fields{}, tls.Certificate{}, false},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
{"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
@ -255,6 +267,95 @@ func TestWriteDefaultIdentity(t *testing.T) {
} }
} }
func TestIdentity_GetClientCertificateFunc(t *testing.T) {
expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key")
if err != nil {
t.Fatal(err)
}
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
want *tls.Certificate
wantErr bool
}{
{"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false},
{"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false},
{"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true},
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
fn := i.GetClientCertificateFunc()
got, err := fn(&tls.CertificateRequestInfo{})
if (err != nil) != tt.wantErr {
t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want)
}
})
}
}
func TestIdentity_GetCertPool(t *testing.T) {
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
wantSubjects [][]byte
wantErr bool
}{
{"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false},
{"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false},
{"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true},
{"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
got, err := i.GetCertPool()
if (err != nil) != tt.wantErr {
t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
subjects := got.Subjects()
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
}
}
})
}
}
type renewer struct { type renewer struct {
pool *x509.CertPool pool *x509.CertPool
sign *api.SignResponse sign *api.SignResponse

View file

@ -0,0 +1,7 @@
{
"type": "mTLS",
"crt": "testdata/identity/identity.crt",
"key": "testdata/identity/identity_key",
"host": "tunnel:443",
"root": "testdata/certs/root_ca.crt"
}