diff --git a/authority/root.go b/authority/root.go index 98974904..cfd63595 100644 --- a/authority/root.go +++ b/authority/root.go @@ -34,7 +34,7 @@ func (a *Authority) GetRootCertificates() []*x509.Certificate { } // GetRoots returns all the root certificates for this CA. -func (a *Authority) GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) { +func (a *Authority) GetRoots(peer *x509.Certificate) ([]*x509.Certificate, error) { // Check step provisioner extensions if err := a.authorizeRenewal(peer); err != nil { return nil, err diff --git a/authority/root_test.go b/authority/root_test.go index d9803d8e..9b80cad6 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -1,11 +1,16 @@ package authority import ( + "crypto/x509" "net/http" + "reflect" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/cli/crypto/keys" + "github.com/smallstep/cli/crypto/pemutil" + "github.com/smallstep/cli/crypto/x509util" ) func TestRoot(t *testing.T) { @@ -43,3 +48,160 @@ func TestRoot(t *testing.T) { }) } } + +func TestAuthority_GetRootCertificate(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + want *x509.Certificate + }{ + {"ok", cert}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + if got := a.GetRootCertificate(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRootCertificate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetRootCertificates(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + want []*x509.Certificate + }{ + {"ok", []*x509.Certificate{cert}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + if got := a.GetRootCertificates(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRootCertificates() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetRoots(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + a := testAuthority(t) + pub, _, err := keys.GenerateDefaultKeyPair() + assert.FatalError(t, err) + leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, + withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test")) + assert.FatalError(t, err) + crtBytes, err := leaf.CreateCertificate() + assert.FatalError(t, err) + crt, err := x509.ParseCertificate(crtBytes) + assert.FatalError(t, err) + + leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, + withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"), + withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), + ) + assert.FatalError(t, err) + crtFailBytes, err := leafFail.CreateCertificate() + assert.FatalError(t, err) + crtFail, err := x509.ParseCertificate(crtFailBytes) + assert.FatalError(t, err) + + type args struct { + peer *x509.Certificate + } + tests := []struct { + name string + args args + want []*x509.Certificate + wantErr bool + }{ + {"ok", args{crt}, []*x509.Certificate{cert}, false}, + {"fail", args{crtFail}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := a.GetRoots(tt.args.peer) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRoots() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetFederation(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + a := testAuthority(t) + pub, _, err := keys.GenerateDefaultKeyPair() + assert.FatalError(t, err) + leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, + withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test")) + assert.FatalError(t, err) + crtBytes, err := leaf.CreateCertificate() + assert.FatalError(t, err) + crt, err := x509.ParseCertificate(crtBytes) + assert.FatalError(t, err) + + leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, + withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"), + withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), + ) + assert.FatalError(t, err) + crtFailBytes, err := leafFail.CreateCertificate() + assert.FatalError(t, err) + crtFail, err := x509.ParseCertificate(crtFailBytes) + assert.FatalError(t, err) + + type args struct { + peer *x509.Certificate + } + tests := []struct { + name string + args args + wantFederation []*x509.Certificate + wantErr bool + fn func() + }{ + {"ok", args{crt}, []*x509.Certificate{cert}, false, nil}, + {"fail", args{crtFail}, nil, true, nil}, + {"fail not a certificate", args{crt}, nil, true, func() { + a.certificates.Store("foo", "bar") + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.fn != nil { + tt.fn() + } + gotFederation, err := a.GetFederation(tt.args.peer) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotFederation, tt.wantFederation) { + t.Errorf("Authority.GetFederation() = %v, want %v", gotFederation, tt.wantFederation) + } + }) + } +}