diff --git a/cas/softcas/softcas_test.go b/cas/softcas/softcas_test.go index 1fd8248a..258ebf5c 100644 --- a/cas/softcas/softcas_test.go +++ b/cas/softcas/softcas_test.go @@ -7,16 +7,19 @@ import ( "crypto/rand" "crypto/x509" "crypto/x509/pkix" + "fmt" "io" "math/big" "reflect" "testing" "time" + "github.com/pkg/errors" + "github.com/smallstep/certificates/cas/apiv1" + "github.com/smallstep/certificates/kms" + kmsapi "github.com/smallstep/certificates/kms/apiv1" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" - - "github.com/smallstep/certificates/cas/apiv1" ) var ( @@ -36,6 +39,7 @@ MC4CAQAwBQYDK2VwBCIEII9ZckcrDKlbhZKR0jp820Uz6mOMLFsq2JhI+Tl7WJwH ) var ( + errTest = errors.New("test error") testIssuer = mustIssuer() testSigner = mustSigner() testTemplate = &x509.Certificate{ @@ -43,13 +47,83 @@ var ( DNSNames: []string{"test.smallstep.com"}, KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, - PublicKey: mustSigner().Public(), + PublicKey: testSigner.Public(), SerialNumber: big.NewInt(1234), } - testNow = time.Now() - testSignedTemplate = mustSign(testTemplate, testNow, testNow.Add(24*time.Hour)) + testRootTemplate = &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Root CA"}, + KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, + PublicKey: testSigner.Public(), + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + SerialNumber: big.NewInt(1234), + } + testIntermediateTemplate = &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Intermediate CA"}, + KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, + PublicKey: testSigner.Public(), + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 0, + MaxPathLenZero: true, + SerialNumber: big.NewInt(1234), + } + testNow = time.Now() + testSignedTemplate = mustSign(testTemplate, testIssuer, testNow, testNow.Add(24*time.Hour)) + testSignedRootTemplate = mustSign(testRootTemplate, testRootTemplate, testNow, testNow.Add(24*time.Hour)) + testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) ) +type mockKeyManager struct { + signer crypto.Signer + errGetPublicKey error + errCreateKey error + errCreatesigner error + errClose error +} + +func (m *mockKeyManager) GetPublicKey(req *kmsapi.GetPublicKeyRequest) (crypto.PublicKey, error) { + signer := testSigner + if m.signer != nil { + signer = m.signer + } + return signer.Public(), m.errGetPublicKey +} + +func (m *mockKeyManager) CreateKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) { + signer := testSigner + if m.signer != nil { + signer = m.signer + } + return &kmsapi.CreateKeyResponse{ + PrivateKey: signer, + PublicKey: signer.Public(), + }, m.errCreateKey +} + +func (m *mockKeyManager) CreateSigner(req *kmsapi.CreateSignerRequest) (crypto.Signer, error) { + signer := testSigner + if m.signer != nil { + signer = m.signer + } + return signer, m.errCreatesigner +} + +func (m *mockKeyManager) Close() error { + return m.errClose +} + +type badSigner struct{} + +func (b *badSigner) Public() crypto.PublicKey { + return testSigner.Public() +} + +func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + return nil, fmt.Errorf("💥") +} + func mockNow(t *testing.T) { tmp := now now = func() time.Time { @@ -76,12 +150,12 @@ func mustSigner() crypto.Signer { return v.(crypto.Signer) } -func mustSign(template *x509.Certificate, notBefore, notAfter time.Time) *x509.Certificate { +func mustSign(template, parent *x509.Certificate, notBefore, notAfter time.Time) *x509.Certificate { tmpl := *template tmpl.NotBefore = notBefore tmpl.NotAfter = notAfter - tmpl.Issuer = testIssuer.Subject - cert, err := x509util.CreateCertificate(&tmpl, testIssuer, tmpl.PublicKey, testSigner) + tmpl.Issuer = parent.Subject + cert, err := x509util.CreateCertificate(&tmpl, parent, tmpl.PublicKey, testSigner) if err != nil { panic(err) } @@ -343,3 +417,125 @@ func Test_now(t *testing.T) { t.Errorf("now() = %s, want ~%s", t1, t0) } } + +func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { + mockNow(t) + + type fields struct { + Issuer *x509.Certificate + Signer crypto.Signer + KeyManager kms.KeyManager + } + type args struct { + req *apiv1.CreateCertificateAuthorityRequest + } + tests := []struct { + name string + fields fields + args args + want *apiv1.CreateCertificateAuthorityResponse + wantErr bool + }{ + {"ok root", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testRootTemplate, + Lifetime: 24 * time.Hour, + }}, &apiv1.CreateCertificateAuthorityResponse{ + Name: "Test Root CA", + Certificate: testSignedRootTemplate, + PublicKey: testSignedRootTemplate.PublicKey, + PrivateKey: testSigner, + Signer: testSigner, + }, false}, + {"ok intermediate", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + Signer: testSigner, + }, + }}, &apiv1.CreateCertificateAuthorityResponse{ + Name: "Test Intermediate CA", + Certificate: testSignedIntermediateTemplate, + CertificateChain: []*x509.Certificate{testSignedRootTemplate}, + PublicKey: testSignedIntermediateTemplate.PublicKey, + PrivateKey: testSigner, + Signer: testSigner, + }, false}, + {"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail lifetime", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testIntermediateTemplate, + }}, nil, true}, + {"fail type", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail parent", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail parent.certificate", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Signer: testSigner, + }, + }}, nil, true}, + {"fail parent.signer", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + }, + }}, nil, true}, + {"fail createKey", fields{nil, nil, &mockKeyManager{errCreateKey: errTest}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail createSigner", fields{nil, nil, &mockKeyManager{errCreatesigner: errTest}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail sign root", fields{nil, nil, &mockKeyManager{signer: &badSigner{}}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail sign intermediate", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + Signer: &badSigner{}, + }, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &SoftCAS{ + Issuer: tt.fields.Issuer, + Signer: tt.fields.Signer, + KeyManager: tt.fields.KeyManager, + } + got, err := c.CreateCertificateAuthority(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("SoftCAS.CreateCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SoftCAS.CreateCertificateAuthority() = \n%#v, want \n%#v", got, tt.want) + } + }) + } +}