From 44207523be1aefbcc056f6a67d0c27a5bc41e9fe Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 21 Jul 2020 14:06:21 -0700 Subject: [PATCH] Add missing tests. --- ca/tls.go | 4 +- x509util/marshal_utils_test.go | 300 +++++++++++++++++++++++++++++++++ x509util/name_test.go | 284 +++++++++++++++++++++++++++++++ x509util/options_test.go | 237 ++++++++++++++++++++++++++ x509util/testdata/example.tpl | 21 +++ 5 files changed, 844 insertions(+), 2 deletions(-) create mode 100644 x509util/marshal_utils_test.go create mode 100644 x509util/name_test.go create mode 100644 x509util/options_test.go create mode 100644 x509util/testdata/example.tpl diff --git a/ca/tls.go b/ca/tls.go index ffae68e5..20a5e504 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -56,7 +56,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, return nil, nil, err } // Use mutable tls.Config on renew - tr.DialTLS = c.buildDialTLS(tlsCtx) //nolint:deprecated + tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) @@ -108,7 +108,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, return nil, err } // Use mutable tls.Config on renew - tr.DialTLS = c.buildDialTLS(tlsCtx) //nolint:deprecated + tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) diff --git a/x509util/marshal_utils_test.go b/x509util/marshal_utils_test.go new file mode 100644 index 00000000..b17f1f72 --- /dev/null +++ b/x509util/marshal_utils_test.go @@ -0,0 +1,300 @@ +package x509util + +import ( + "encoding/asn1" + "encoding/json" + "net" + "net/url" + "reflect" + "testing" +) + +func TestMultiString_MarshalJSON(t *testing.T) { + tests := []struct { + name string + m MultiString + want []byte + wantErr bool + }{ + {"ok", []string{"foo", "bar"}, []byte(`["foo","bar"]`), false}, + {"empty", []string{}, []byte(`[]`), false}, + {"nil", nil, []byte(`null`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.m) + if (err != nil) != tt.wantErr { + t.Errorf("MultiIPNet.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiIPNet.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiString_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want MultiString + wantErr bool + }{ + {"string", args{[]byte(`"foo"`)}, []string{"foo"}, false}, + {"array", args{[]byte(`["foo", "bar", "zar"]`)}, []string{"foo", "bar", "zar"}, false}, + {"empty", args{[]byte(`[]`)}, []string{}, false}, + {"null", args{[]byte(`null`)}, nil, false}, + {"fail", args{[]byte(`["foo"`)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got MultiString + if err := got.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("MultiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiString.UnmarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiIP_MarshalJSON(t *testing.T) { + tests := []struct { + name string + m MultiIP + want []byte + wantErr bool + }{ + {"ok", []net.IP{net.ParseIP("::1"), net.ParseIP("1.2.3.4")}, []byte(`["::1","1.2.3.4"]`), false}, + {"empty", []net.IP{}, []byte(`[]`), false}, + {"nil", nil, []byte(`null`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.m) + if (err != nil) != tt.wantErr { + t.Errorf("MultiIPNet.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiIPNet.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiIP_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want MultiIP + wantErr bool + }{ + {"string", args{[]byte(`"::1"`)}, []net.IP{net.ParseIP("::1")}, false}, + {"array", args{[]byte(`["127.0.0.1", "::1"]`)}, []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, false}, + {"empty", args{[]byte(`[]`)}, []net.IP{}, false}, + {"null", args{[]byte(`null`)}, nil, false}, + {"fail", args{[]byte(`"foo.bar"`)}, nil, true}, + {"failJSON", args{[]byte(`["::1"`)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got MultiIP + if err := got.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("MultiIP.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiIP.UnmarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiIPNet_MarshalJSON(t *testing.T) { + ipNet := func(s string) *net.IPNet { + _, ipNet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + return ipNet + } + + tests := []struct { + name string + m MultiIPNet + want []byte + wantErr bool + }{ + {"ok", []*net.IPNet{ipNet("1.1.0.0/16"), ipNet("2001:db8:8a2e:7334::/64")}, []byte(`["1.1.0.0/16","2001:db8:8a2e:7334::/64"]`), false}, + {"empty", []*net.IPNet{}, []byte(`[]`), false}, + {"nil", nil, []byte(`null`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MultiIPNet.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiIPNet.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiIPNet_UnmarshalJSON(t *testing.T) { + ipNet := func(s string) *net.IPNet { + _, ipNet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + return ipNet + } + + type args struct { + data []byte + } + tests := []struct { + name string + args args + want MultiIPNet + wantErr bool + }{ + {"string", args{[]byte(`"1.1.0.0/16"`)}, []*net.IPNet{ipNet("1.1.0.0/16")}, false}, + {"array", args{[]byte(`["1.0.0.0/24", "2.1.0.0/16"]`)}, []*net.IPNet{ipNet("1.0.0.0/24"), ipNet("2.1.0.0/16")}, false}, + {"empty", args{[]byte(`[]`)}, []*net.IPNet{}, false}, + {"null", args{[]byte(`null`)}, nil, false}, + {"fail", args{[]byte(`"foo.bar/16"`)}, nil, true}, + {"failJSON", args{[]byte(`["1.0.0.0/24"`)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got MultiIPNet + if err := got.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("MultiIPNet.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiIPNet.UnmarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiURL_MarshalJSON(t *testing.T) { + tests := []struct { + name string + m MultiURL + want []byte + wantErr bool + }{ + {"ok", []*url.URL{{Scheme: "https", Host: "iss", Fragment: "sub"}, {Scheme: "uri", Opaque: "foo:bar"}}, []byte(`["https://iss#sub","uri:foo:bar"]`), false}, + {"empty", []*url.URL{}, []byte(`[]`), false}, + {"nil", nil, []byte(`null`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MultiURL.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiURL.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiURL_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want MultiURL + wantErr bool + }{ + {"string", args{[]byte(`"https://iss#sub"`)}, []*url.URL{{Scheme: "https", Host: "iss", Fragment: "sub"}}, false}, + {"array", args{[]byte(`["https://iss#sub", "uri:foo:bar"]`)}, []*url.URL{{Scheme: "https", Host: "iss", Fragment: "sub"}, {Scheme: "uri", Opaque: "foo:bar"}}, false}, + {"empty", args{[]byte(`[]`)}, []*url.URL{}, false}, + {"null", args{[]byte(`null`)}, nil, false}, + {"fail", args{[]byte(`":foo:bar"`)}, nil, true}, + {"failJSON", args{[]byte(`["https://iss#sub"`)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got MultiURL + if err := got.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("MultiURL.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiURL.UnmarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiObjectIdentifier_MarshalJSON(t *testing.T) { + tests := []struct { + name string + m MultiObjectIdentifier + want []byte + wantErr bool + }{ + {"ok", []asn1.ObjectIdentifier{[]int{1, 2, 3, 4}, []int{5, 6, 7, 8, 9, 0}}, []byte(`["1.2.3.4","5.6.7.8.9.0"]`), false}, + {"empty", []asn1.ObjectIdentifier{}, []byte(`[]`), false}, + {"nil", nil, []byte(`null`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.m) + if (err != nil) != tt.wantErr { + t.Errorf("MultiObjectIdentifier.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiObjectIdentifier.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiObjectIdentifier_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want MultiObjectIdentifier + wantErr bool + }{ + {"string", args{[]byte(`"1.2.3.4"`)}, []asn1.ObjectIdentifier{[]int{1, 2, 3, 4}}, false}, + {"array", args{[]byte(`["1.2.3.4", "5.6.7.8.9.0"]`)}, []asn1.ObjectIdentifier{[]int{1, 2, 3, 4}, []int{5, 6, 7, 8, 9, 0}}, false}, + {"empty", args{[]byte(`[]`)}, []asn1.ObjectIdentifier{}, false}, + {"null", args{[]byte(`null`)}, nil, false}, + {"fail", args{[]byte(`":foo:bar"`)}, nil, true}, + {"failJSON", args{[]byte(`["https://iss#sub"`)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got MultiObjectIdentifier + if err := got.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("MultiObjectIdentifier.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MultiObjectIdentifier.UnmarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/x509util/name_test.go b/x509util/name_test.go new file mode 100644 index 00000000..1000f19a --- /dev/null +++ b/x509util/name_test.go @@ -0,0 +1,284 @@ +package x509util + +import ( + "crypto/x509" + "crypto/x509/pkix" + "reflect" + "testing" +) + +func TestName_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want Name + wantErr bool + }{ + {"null", args{[]byte("null")}, Name{}, false}, + {"empty", args{[]byte("{}")}, Name{}, false}, + {"commonName", args{[]byte(`"commonName"`)}, Name{CommonName: "commonName"}, false}, + {"object", args{[]byte(`{ + "country": "The country", + "organization": "The organization", + "organizationalUnit": ["The organizationalUnit 1", "The organizationalUnit 2"], + "locality": ["The locality 1", "The locality 2"], + "province": "The province", + "streetAddress": "The streetAddress", + "postalCode": "The postalCode", + "serialNumber": "The serialNumber", + "commonName": "The commonName" + }`)}, Name{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }, false}, + {"number", args{[]byte("1234")}, Name{}, true}, + {"badJSON", args{[]byte("'badJSON'")}, Name{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Name + if err := got.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("Name.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Name.UnmarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_newSubject(t *testing.T) { + type args struct { + n pkix.Name + } + tests := []struct { + name string + args args + want Subject + }{ + {"ok", args{pkix.Name{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }}, Subject{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newSubject(tt.args.n); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newSubject() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSubject_Set(t *testing.T) { + type fields struct { + Country MultiString + Organization MultiString + OrganizationalUnit MultiString + Locality MultiString + Province MultiString + StreetAddress MultiString + PostalCode MultiString + SerialNumber string + CommonName string + } + type args struct { + c *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + want *x509.Certificate + }{ + {"ok", fields{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }, args{&x509.Certificate{}}, &x509.Certificate{ + Subject: pkix.Name{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }, + }}, + {"overwrite", fields{ + CommonName: "The commonName", + }, args{&x509.Certificate{}}, &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "The commonName", + }, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := Subject{ + Country: tt.fields.Country, + Organization: tt.fields.Organization, + OrganizationalUnit: tt.fields.OrganizationalUnit, + Locality: tt.fields.Locality, + Province: tt.fields.Province, + StreetAddress: tt.fields.StreetAddress, + PostalCode: tt.fields.PostalCode, + SerialNumber: tt.fields.SerialNumber, + CommonName: tt.fields.CommonName, + } + s.Set(tt.args.c) + if !reflect.DeepEqual(tt.args.c, tt.want) { + t.Errorf("Subject.Set() = %v, want %v", tt.args.c, tt.want) + } + }) + } +} + +func Test_newIssuer(t *testing.T) { + type args struct { + n pkix.Name + } + tests := []struct { + name string + args args + want Issuer + }{ + {"ok", args{pkix.Name{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }}, Issuer{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newIssuer(tt.args.n); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newIssuer() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIssuer_Set(t *testing.T) { + type fields struct { + Country MultiString + Organization MultiString + OrganizationalUnit MultiString + Locality MultiString + Province MultiString + StreetAddress MultiString + PostalCode MultiString + SerialNumber string + CommonName string + } + type args struct { + c *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + want *x509.Certificate + }{ + {"ok", fields{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }, args{&x509.Certificate{}}, &x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"The country"}, + Organization: []string{"The organization"}, + OrganizationalUnit: []string{"The organizationalUnit 1", "The organizationalUnit 2"}, + Locality: []string{"The locality 1", "The locality 2"}, + Province: []string{"The province"}, + StreetAddress: []string{"The streetAddress"}, + PostalCode: []string{"The postalCode"}, + SerialNumber: "The serialNumber", + CommonName: "The commonName", + }, + }}, + {"overwrite", fields{ + CommonName: "The commonName", + }, args{&x509.Certificate{}}, &x509.Certificate{ + Issuer: pkix.Name{ + CommonName: "The commonName", + }, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := Issuer{ + Country: tt.fields.Country, + Organization: tt.fields.Organization, + OrganizationalUnit: tt.fields.OrganizationalUnit, + Locality: tt.fields.Locality, + Province: tt.fields.Province, + StreetAddress: tt.fields.StreetAddress, + PostalCode: tt.fields.PostalCode, + SerialNumber: tt.fields.SerialNumber, + CommonName: tt.fields.CommonName, + } + i.Set(tt.args.c) + if !reflect.DeepEqual(tt.args.c, tt.want) { + t.Errorf("Issuer.Set() = %v, want %v", tt.args.c, tt.want) + } + }) + } +} diff --git a/x509util/options_test.go b/x509util/options_test.go new file mode 100644 index 00000000..7d2ee725 --- /dev/null +++ b/x509util/options_test.go @@ -0,0 +1,237 @@ +package x509util + +import ( + "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "reflect" + "testing" + + "github.com/pkg/errors" +) + +func createRSACertificateRequest(t *testing.T, bits int, commonName string, sans []string) (*x509.CertificateRequest, crypto.Signer) { + dnsNames, ips, emails, uris := SplitSANs(sans) + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + t.Fatal(err) + } + asn1Data, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{ + Subject: pkix.Name{CommonName: commonName}, + DNSNames: dnsNames, + IPAddresses: ips, + EmailAddresses: emails, + URIs: uris, + SignatureAlgorithm: x509.SHA256WithRSAPSS, + }, priv) + if err != nil { + t.Fatal(err) + } + cr, err := x509.ParseCertificateRequest(asn1Data) + if err != nil { + t.Fatal(err) + } + return cr, priv +} + +func Test_getFuncMap_fail(t *testing.T) { + var failMesage string + fns := getFuncMap(&failMesage) + fail := fns["fail"].(func(s string) (string, error)) + s, err := fail("the fail message") + if err == nil { + t.Errorf("fail() error = %v, wantErr %v", err, errors.New("the fail message")) + } + if s != "" { + t.Errorf("fail() = \"%s\", want \"the fail message\"", s) + } + if failMesage != "the fail message" { + t.Errorf("fail() message = \"%s\", want \"the fail message\"", failMesage) + } +} + +func TestWithTemplate(t *testing.T) { + cr, _ := createCertificateRequest(t, "foo", []string{"foo.com", "foo@foo.com", "::1", "https://foo.com"}) + crRSA, _ := createRSACertificateRequest(t, 2048, "foo", []string{"foo.com", "foo@foo.com", "::1", "https://foo.com"}) + type args struct { + text string + data TemplateData + cr *x509.CertificateRequest + } + tests := []struct { + name string + args args + want Options + wantErr bool + }{ + {"leaf", args{DefaultLeafTemplate, TemplateData{ + SubjectKey: Subject{CommonName: "foo"}, + SANsKey: []SubjectAlternativeName{{Type: "dns", Value: "foo.com"}}, + }, cr}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "subject": {"commonName":"foo"}, + "sans": [{"type":"dns","value":"foo.com"}], + "keyUsage": ["digitalSignature"], + "extKeyUsage": ["serverAuth", "clientAuth"] +}`), + }, false}, + {"leafRSA", args{DefaultLeafTemplate, TemplateData{ + SubjectKey: Subject{CommonName: "foo"}, + SANsKey: []SubjectAlternativeName{{Type: "dns", Value: "foo.com"}}, + }, crRSA}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "subject": {"commonName":"foo"}, + "sans": [{"type":"dns","value":"foo.com"}], + "keyUsage": ["keyEncipherment", "digitalSignature"], + "extKeyUsage": ["serverAuth", "clientAuth"] +}`), + }, false}, + {"iid", args{DefaultIIDLeafTemplate, TemplateData{}, cr}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "subject": {"commonName":"foo"}, + "dnsNames": ["foo.com"], + "emailAddresses": ["foo@foo.com"], + "ipAddresses": ["::1"], + "uris": ["https://foo.com"], + "keyUsage": ["digitalSignature"], + "extKeyUsage": ["serverAuth", "clientAuth"] +}`), + }, false}, + {"iidRSAAndEnforced", args{DefaultIIDLeafTemplate, TemplateData{ + SANsKey: []SubjectAlternativeName{{Type: "dns", Value: "foo.com"}}, + }, crRSA}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "subject": {"commonName":"foo"}, + "sans": [{"type":"dns","value":"foo.com"}], + "keyUsage": ["keyEncipherment", "digitalSignature"], + "extKeyUsage": ["serverAuth", "clientAuth"] +}`), + }, false}, + {"fail", args{`{{ fail "a message" }}`, TemplateData{}, cr}, Options{}, true}, + {"error", args{`{{ mustHas 3 .Data }}`, TemplateData{ + "Data": 3, + }, cr}, Options{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Options + fn := WithTemplate(tt.args.text, tt.args.data) + if err := fn(tt.args.cr, &got); (err != nil) != tt.wantErr { + t.Errorf("WithTemplate() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithTemplate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWithTemplateBase64(t *testing.T) { + cr, _ := createCertificateRequest(t, "foo", []string{"foo.com", "foo@foo.com", "::1", "https://foo.com"}) + type args struct { + s string + data TemplateData + cr *x509.CertificateRequest + } + tests := []struct { + name string + args args + want Options + wantErr bool + }{ + {"leaf", args{base64.StdEncoding.EncodeToString([]byte(DefaultLeafTemplate)), TemplateData{ + SubjectKey: Subject{CommonName: "foo"}, + SANsKey: []SubjectAlternativeName{{Type: "dns", Value: "foo.com"}}, + }, cr}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "subject": {"commonName":"foo"}, + "sans": [{"type":"dns","value":"foo.com"}], + "keyUsage": ["digitalSignature"], + "extKeyUsage": ["serverAuth", "clientAuth"] +}`), + }, false}, + {"badBase64", args{"foobar", TemplateData{}, cr}, Options{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Options + fn := WithTemplateBase64(tt.args.s, tt.args.data) + if err := fn(tt.args.cr, &got); (err != nil) != tt.wantErr { + t.Errorf("WithTemplateBase64() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithTemplateBase64() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWithTemplateFile(t *testing.T) { + cr, _ := createCertificateRequest(t, "foo", []string{"foo.com", "foo@foo.com", "::1", "https://foo.com"}) + rsa2048, _ := createRSACertificateRequest(t, 2048, "foo", []string{"foo.com", "foo@foo.com", "::1", "https://foo.com"}) + rsa3072, _ := createRSACertificateRequest(t, 3072, "foo", []string{"foo.com", "foo@foo.com", "::1", "https://foo.com"}) + + data := TemplateData{ + SANsKey: []SubjectAlternativeName{ + {Type: "dns", Value: "foo.com"}, + {Type: "email", Value: "root@foo.com"}, + {Type: "ip", Value: "127.0.0.1"}, + {Type: "uri", Value: "uri:foo:bar"}, + }, + TokenKey: map[string]interface{}{ + "iss": "https://iss", + "sub": "sub", + }, + } + type args struct { + path string + data TemplateData + cr *x509.CertificateRequest + } + tests := []struct { + name string + args args + want Options + wantErr bool + }{ + {"example", args{"./testdata/example.tpl", data, cr}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "subject": {"commonName":"foo"}, + "sans": [{"type":"dns","value":"foo.com"},{"type":"email","value":"root@foo.com"},{"type":"ip","value":"127.0.0.1"},{"type":"uri","value":"uri:foo:bar"}], + "emailAddresses": ["foo@foo.com"], + "uris": "https://iss#sub", + "keyUsage": ["digitalSignature"], + "extKeyUsage": ["serverAuth", "clientAuth"] +}`), + }, false}, + {"exampleRSA3072", args{"./testdata/example.tpl", data, rsa3072}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "subject": {"commonName":"foo"}, + "sans": [{"type":"dns","value":"foo.com"},{"type":"email","value":"root@foo.com"},{"type":"ip","value":"127.0.0.1"},{"type":"uri","value":"uri:foo:bar"}], + "emailAddresses": ["foo@foo.com"], + "uris": "https://iss#sub", + "keyUsage": ["keyEncipherment", "digitalSignature"], + "extKeyUsage": ["serverAuth", "clientAuth"] +}`), + }, false}, + {"exampleRSA2048", args{"./testdata/example.tpl", data, rsa2048}, Options{}, true}, + {"missing", args{"./testdata/missing.tpl", data, cr}, Options{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Options + fn := WithTemplateFile(tt.args.path, tt.args.data) + if err := fn(tt.args.cr, &got); (err != nil) != tt.wantErr { + t.Errorf("WithTemplateFile() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithTemplateFile() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/x509util/testdata/example.tpl b/x509util/testdata/example.tpl new file mode 100644 index 00000000..286edb5c --- /dev/null +++ b/x509util/testdata/example.tpl @@ -0,0 +1,21 @@ +{ + "subject": {{ toJson .Insecure.CR.Subject }}, + "sans": {{ toJson .SANs }}, +{{- if .Insecure.CR.EmailAddresses }} + "emailAddresses": {{ toJson .Insecure.CR.EmailAddresses }}, +{{- end }} +{{- if .Token }} + "uris": "{{ .Token.iss }}#{{ .Token.sub }}", +{{- end }} +{{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + {{- if lt .Insecure.CR.PublicKey.Size 384 }} + {{ fail "Key length must be at least 3072 bits" }} + {{- end }} +{{- end }} +{{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + "keyUsage": ["keyEncipherment", "digitalSignature"], +{{- else }} + "keyUsage": ["digitalSignature"], +{{- end }} + "extKeyUsage": ["serverAuth", "clientAuth"] +} \ No newline at end of file