diff --git a/x509util/certificate_test.go b/x509util/certificate_test.go index ea08de28..140578ec 100644 --- a/x509util/certificate_test.go +++ b/x509util/certificate_test.go @@ -7,11 +7,14 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/asn1" + "fmt" + "io" "math/big" "net" "net/url" "reflect" "testing" + "time" ) func createCertificateRequest(t *testing.T, commonName string, sans []string) (*x509.CertificateRequest, crypto.Signer) { @@ -39,6 +42,68 @@ func createCertificateRequest(t *testing.T, commonName string, sans []string) (* return cr, priv } +func createIssuerCertificate(t *testing.T, commonName string) (*x509.Certificate, crypto.Signer) { + t.Helper() + now := time.Now() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + subjectKeyID, err := generateSubjectKeyID(pub) + if err != nil { + t.Fatal(err) + } + sn, err := generateSerialNumber() + if err != nil { + t.Fatal(err) + } + + template := &x509.Certificate{ + IsCA: true, + NotBefore: now, + NotAfter: now.Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + MaxPathLen: 0, + MaxPathLenZero: true, + Issuer: pkix.Name{CommonName: "issuer"}, + Subject: pkix.Name{CommonName: "issuer"}, + SerialNumber: sn, + SubjectKeyId: subjectKeyID, + } + asn1Data, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv) + if err != nil { + t.Fatal(err) + } + crt, err := x509.ParseCertificate(asn1Data) + if err != nil { + t.Fatal(err) + } + return crt, priv +} + +type badSigner struct { + pub crypto.PublicKey +} + +func createBadSigner(t *testing.T) *badSigner { + pub, _, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + return &badSigner{ + pub: pub, + } +} + +func (b *badSigner) Public() crypto.PublicKey { + return b.pub +} + +func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + return nil, fmt.Errorf("💥") +} + func TestNewCertificate(t *testing.T) { cr, priv := createCertificateRequest(t, "commonName", []string{"foo.com"}) crBadSignateure, _ := createCertificateRequest(t, "fail", []string{"foo.com"}) @@ -222,3 +287,76 @@ func TestCertificate_GetCertificate(t *testing.T) { }) } } + +func TestCreateCertificate(t *testing.T) { + iss, issPriv := createIssuerCertificate(t, "issuer") + + mustSerialNumber := func() *big.Int { + sn, err := generateSerialNumber() + if err != nil { + t.Fatal(err) + } + return sn + } + mustSubjectKeyID := func(pub crypto.PublicKey) []byte { + b, err := generateSubjectKeyID(pub) + if err != nil { + t.Fatal(err) + } + return b + } + + cr1, priv1 := createCertificateRequest(t, "commonName", []string{"foo.com"}) + crt1 := newCertificateRequest(cr1).GetLeafCertificate().GetCertificate() + crt1.SerialNumber = mustSerialNumber() + crt1.SubjectKeyId = mustSubjectKeyID(priv1.Public()) + + cr2, priv2 := createCertificateRequest(t, "commonName", []string{"foo.com"}) + crt2 := newCertificateRequest(cr2).GetLeafCertificate().GetCertificate() + crt2.SerialNumber = mustSerialNumber() + + cr3, priv3 := createCertificateRequest(t, "commonName", []string{"foo.com"}) + crt3 := newCertificateRequest(cr3).GetLeafCertificate().GetCertificate() + crt3.SubjectKeyId = mustSubjectKeyID(priv1.Public()) + + cr4, priv4 := createCertificateRequest(t, "commonName", []string{"foo.com"}) + crt4 := newCertificateRequest(cr4).GetLeafCertificate().GetCertificate() + + cr5, _ := createCertificateRequest(t, "commonName", []string{"foo.com"}) + crt5 := newCertificateRequest(cr5).GetLeafCertificate().GetCertificate() + + badSigner := createBadSigner(t) + + type args struct { + template *x509.Certificate + parent *x509.Certificate + pub crypto.PublicKey + signer crypto.Signer + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{crt1, iss, priv1.Public(), issPriv}, false}, + {"okNoSubjectKeyID", args{crt2, iss, priv2.Public(), issPriv}, false}, + {"okNoSerialNumber", args{crt3, iss, priv3.Public(), issPriv}, false}, + {"okNothing", args{crt4, iss, priv4.Public(), issPriv}, false}, + {"failSubjectKeyID", args{crt5, iss, []byte("foo"), issPriv}, true}, + {"failSign", args{crt1, iss, priv1.Public(), badSigner}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CreateCertificate(tt.args.template, tt.args.parent, tt.args.pub, tt.args.signer) + if (err != nil) != tt.wantErr { + t.Errorf("CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if err := got.CheckSignatureFrom(iss); err != nil { + t.Errorf("Certificate.CheckSignatureFrom() error = %v", err) + } + } + }) + } +}