Add some ssh related tests.

This commit is contained in:
Mariano Cano 2019-10-14 17:10:47 -07:00 committed by max furman
parent 385bf0a14a
commit 4f06f3901e
4 changed files with 200 additions and 8 deletions

View file

@ -4,17 +4,20 @@ import (
"crypto/x509" "crypto/x509"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"golang.org/x/crypto/ssh"
) )
type MockAuthDB struct { type MockAuthDB struct {
err error err error
ret1 interface{} ret1 interface{}
init func(*db.Config) (db.AuthDB, error) init func(*db.Config) (db.AuthDB, error)
isRevoked func(string) (bool, error) isRevoked func(string) (bool, error)
revoke func(rci *db.RevokedCertificateInfo) error revoke func(rci *db.RevokedCertificateInfo) error
storeCertificate func(crt *x509.Certificate) error storeCertificate func(crt *x509.Certificate) error
useToken func(id, tok string) (bool, error) useToken func(id, tok string) (bool, error)
shutdown func() error isSSHHost func(principal string) (bool, error)
storeSSHCertificate func(crt *ssh.Certificate) error
shutdown func() error
} }
func (m *MockAuthDB) Init(c *db.Config) (db.AuthDB, error) { func (m *MockAuthDB) Init(c *db.Config) (db.AuthDB, error) {
@ -58,6 +61,20 @@ func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
return m.err return m.err
} }
func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) {
if m.isSSHHost != nil {
return m.isSSHHost(principal)
}
return m.ret1.(bool), m.err
}
func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error {
if m.storeSSHCertificate != nil {
return m.storeSSHCertificate(crt)
}
return m.err
}
func (m *MockAuthDB) Shutdown() error { func (m *MockAuthDB) Shutdown() error {
if m.shutdown != nil { if m.shutdown != nil {
return m.shutdown() return m.shutdown()

View file

@ -4,12 +4,15 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"encoding/base64"
"fmt" "fmt"
"reflect"
"testing" "testing"
"time" "time"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -253,3 +256,167 @@ func TestAuthority_SignSSHAddUser(t *testing.T) {
}) })
} }
} }
func TestAuthority_GetSSHRoots(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
user, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
host, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
type fields struct {
sshCAUserCerts []ssh.PublicKey
sshCAHostCerts []ssh.PublicKey
}
tests := []struct {
name string
fields fields
want *SSHKeys
wantErr bool
}{
{"ok", fields{[]ssh.PublicKey{user}, []ssh.PublicKey{host}}, &SSHKeys{UserKeys: []ssh.PublicKey{user}, HostKeys: []ssh.PublicKey{host}}, false},
{"nil", fields{}, &SSHKeys{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
a.sshCAUserCerts = tt.fields.sshCAUserCerts
a.sshCAHostCerts = tt.fields.sshCAHostCerts
got, err := a.GetSSHRoots()
if (err != nil) != tt.wantErr {
t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetSSHRoots() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthority_GetSSHFederation(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
user, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
host, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
type fields struct {
sshCAUserFederatedCerts []ssh.PublicKey
sshCAHostFederatedCerts []ssh.PublicKey
}
tests := []struct {
name string
fields fields
want *SSHKeys
wantErr bool
}{
{"ok", fields{[]ssh.PublicKey{user}, []ssh.PublicKey{host}}, &SSHKeys{UserKeys: []ssh.PublicKey{user}, HostKeys: []ssh.PublicKey{host}}, false},
{"nil", fields{}, &SSHKeys{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts
a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts
got, err := a.GetSSHFederation()
if (err != nil) != tt.wantErr {
t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetSSHFederation() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthority_GetSSHConfig(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
user, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
userSigner, err := ssh.NewSignerFromSigner(key)
assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err)
host, err := ssh.NewPublicKey(key.Public())
assert.FatalError(t, err)
hostSigner, err := ssh.NewSignerFromSigner(key)
assert.FatalError(t, err)
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
tmplConfig := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "known_host.tpl", Type: templates.File, TemplatePath: "./testdata/templates/known_hosts.tpl", Path: "ssh/known_host", Comment: "#"},
},
Host: []templates.Template{
{Name: "ca.tpl", Type: templates.File, TemplatePath: "./testdata/templates/ca.tpl", Path: "/etc/ssh/ca.pub", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}
userOutput := []templates.Output{
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte(fmt.Sprintf("@cert-authority * %s %s", host.Type(), hostB64))},
}
hostOutput := []templates.Output{
{Name: "ca.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/ca.pub", Content: []byte(user.Type() + " " + userB64)},
}
type fields struct {
templates *templates.Templates
userSigner ssh.Signer
hostSigner ssh.Signer
}
type args struct {
typ string
data map[string]string
}
tests := []struct {
name string
fields fields
args args
want []templates.Output
wantErr bool
}{
{"user", fields{tmplConfig, userSigner, hostSigner}, args{"user", nil}, userOutput, false},
{"host", fields{tmplConfig, userSigner, hostSigner}, args{"host", nil}, hostOutput, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
a.config.Templates = tt.fields.templates
a.sshCAUserCertSignKey = tt.fields.userSigner
a.sshCAHostCertSignKey = tt.fields.hostSigner
got, err := a.GetSSHConfig(tt.args.typ, tt.args.data)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetSSHConfig() = %v, want %v", got, tt.want)
}
})
}
}

4
authority/testdata/templates/ca.tpl vendored Normal file
View file

@ -0,0 +1,4 @@
{{.Step.SSH.UserKey.Type}} {{.Step.SSH.UserKey.Marshal | toString | b64enc}}
{{- range .Step.SSH.UserFederatedKeys}}
{{.Type}} {{.Marshal | toString | b64enc}}
{{- end}}

View file

@ -0,0 +1,4 @@
@cert-authority * {{.Step.SSH.HostKey.Type}} {{.Step.SSH.HostKey.Marshal | toString | b64enc}}
{{- range .Step.SSH.HostFederatedKeys}}
@cert-authority * {{.Type}} {{.Marshal | toString | b64enc}}
{{- end}}