diff --git a/authority/db_test.go b/authority/db_test.go index e3834b99..bd6b27ca 100644 --- a/authority/db_test.go +++ b/authority/db_test.go @@ -4,17 +4,20 @@ import ( "crypto/x509" "github.com/smallstep/certificates/db" + "golang.org/x/crypto/ssh" ) type MockAuthDB struct { - err error - ret1 interface{} - init func(*db.Config) (db.AuthDB, error) - isRevoked func(string) (bool, error) - revoke func(rci *db.RevokedCertificateInfo) error - storeCertificate func(crt *x509.Certificate) error - useToken func(id, tok string) (bool, error) - shutdown func() error + err error + ret1 interface{} + init func(*db.Config) (db.AuthDB, error) + isRevoked func(string) (bool, error) + revoke func(rci *db.RevokedCertificateInfo) error + storeCertificate func(crt *x509.Certificate) error + useToken func(id, tok string) (bool, error) + isSSHHost func(principal string) (bool, error) + storeSSHCertificate func(crt *ssh.Certificate) error + shutdown func() error } func (m *MockAuthDB) Init(c *db.Config) (db.AuthDB, error) { @@ -58,6 +61,20 @@ func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error { return m.err } +func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) { + if m.isSSHHost != nil { + return m.isSSHHost(principal) + } + return m.ret1.(bool), m.err +} + +func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error { + if m.storeSSHCertificate != nil { + return m.storeSSHCertificate(crt) + } + return m.err +} + func (m *MockAuthDB) Shutdown() error { if m.shutdown != nil { return m.shutdown() diff --git a/authority/ssh_test.go b/authority/ssh_test.go index ff0bb23c..872278d6 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -4,12 +4,15 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "encoding/base64" "fmt" + "reflect" "testing" "time" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/templates" "golang.org/x/crypto/ssh" ) @@ -253,3 +256,167 @@ func TestAuthority_SignSSHAddUser(t *testing.T) { }) } } + +func TestAuthority_GetSSHRoots(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + user, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + host, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + + type fields struct { + sshCAUserCerts []ssh.PublicKey + sshCAHostCerts []ssh.PublicKey + } + tests := []struct { + name string + fields fields + want *SSHKeys + wantErr bool + }{ + {"ok", fields{[]ssh.PublicKey{user}, []ssh.PublicKey{host}}, &SSHKeys{UserKeys: []ssh.PublicKey{user}, HostKeys: []ssh.PublicKey{host}}, false}, + {"nil", fields{}, &SSHKeys{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + a.sshCAUserCerts = tt.fields.sshCAUserCerts + a.sshCAHostCerts = tt.fields.sshCAHostCerts + + got, err := a.GetSSHRoots() + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetSSHRoots() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetSSHFederation(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + user, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + host, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + + type fields struct { + sshCAUserFederatedCerts []ssh.PublicKey + sshCAHostFederatedCerts []ssh.PublicKey + } + tests := []struct { + name string + fields fields + want *SSHKeys + wantErr bool + }{ + {"ok", fields{[]ssh.PublicKey{user}, []ssh.PublicKey{host}}, &SSHKeys{UserKeys: []ssh.PublicKey{user}, HostKeys: []ssh.PublicKey{host}}, false}, + {"nil", fields{}, &SSHKeys{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts + a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts + + got, err := a.GetSSHFederation() + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetSSHFederation() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetSSHConfig(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + user, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + userSigner, err := ssh.NewSignerFromSigner(key) + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + host, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + hostSigner, err := ssh.NewSignerFromSigner(key) + assert.FatalError(t, err) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + tmplConfig := &templates.Templates{ + SSH: &templates.SSHTemplates{ + User: []templates.Template{ + {Name: "known_host.tpl", Type: templates.File, TemplatePath: "./testdata/templates/known_hosts.tpl", Path: "ssh/known_host", Comment: "#"}, + }, + Host: []templates.Template{ + {Name: "ca.tpl", Type: templates.File, TemplatePath: "./testdata/templates/ca.tpl", Path: "/etc/ssh/ca.pub", Comment: "#"}, + }, + }, + Data: map[string]interface{}{ + "Step": &templates.Step{ + SSH: templates.StepSSH{ + UserKey: user, + HostKey: host, + }, + }, + }, + } + userOutput := []templates.Output{ + {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte(fmt.Sprintf("@cert-authority * %s %s", host.Type(), hostB64))}, + } + hostOutput := []templates.Output{ + {Name: "ca.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/ca.pub", Content: []byte(user.Type() + " " + userB64)}, + } + + type fields struct { + templates *templates.Templates + userSigner ssh.Signer + hostSigner ssh.Signer + } + type args struct { + typ string + data map[string]string + } + tests := []struct { + name string + fields fields + args args + want []templates.Output + wantErr bool + }{ + {"user", fields{tmplConfig, userSigner, hostSigner}, args{"user", nil}, userOutput, false}, + {"host", fields{tmplConfig, userSigner, hostSigner}, args{"host", nil}, hostOutput, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + a.config.Templates = tt.fields.templates + a.sshCAUserCertSignKey = tt.fields.userSigner + a.sshCAHostCertSignKey = tt.fields.hostSigner + + got, err := a.GetSSHConfig(tt.args.typ, tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetSSHConfig() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/testdata/templates/ca.tpl b/authority/testdata/templates/ca.tpl new file mode 100644 index 00000000..21235dd5 --- /dev/null +++ b/authority/testdata/templates/ca.tpl @@ -0,0 +1,4 @@ +{{.Step.SSH.UserKey.Type}} {{.Step.SSH.UserKey.Marshal | toString | b64enc}} +{{- range .Step.SSH.UserFederatedKeys}} +{{.Type}} {{.Marshal | toString | b64enc}} +{{- end}} \ No newline at end of file diff --git a/authority/testdata/templates/known_hosts.tpl b/authority/testdata/templates/known_hosts.tpl new file mode 100644 index 00000000..acc0fafe --- /dev/null +++ b/authority/testdata/templates/known_hosts.tpl @@ -0,0 +1,4 @@ +@cert-authority * {{.Step.SSH.HostKey.Type}} {{.Step.SSH.HostKey.Marshal | toString | b64enc}} +{{- range .Step.SSH.HostFederatedKeys}} +@cert-authority * {{.Type}} {{.Marshal | toString | b64enc}} +{{- end}} \ No newline at end of file