Add tests for ssh api methods.
This commit is contained in:
parent
dc24f27647
commit
c563143e6d
3 changed files with 240 additions and 5 deletions
|
@ -516,6 +516,7 @@ type mockAuthority struct {
|
||||||
getSSHRoots func() (*authority.SSHKeys, error)
|
getSSHRoots func() (*authority.SSHKeys, error)
|
||||||
getSSHFederation func() (*authority.SSHKeys, error)
|
getSSHFederation func() (*authority.SSHKeys, error)
|
||||||
getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error)
|
getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error)
|
||||||
|
checkSSHHost func(principal string) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove once Authorize is deprecated.
|
// TODO: remove once Authorize is deprecated.
|
||||||
|
@ -642,6 +643,13 @@ func (m *mockAuthority) GetSSHConfig(typ string, data map[string]string) ([]temp
|
||||||
return m.ret1.([]templates.Output), m.err
|
return m.ret1.([]templates.Output), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) CheckSSHHost(principal string) (bool, error) {
|
||||||
|
if m.checkSSHHost != nil {
|
||||||
|
return m.checkSSHHost(principal)
|
||||||
|
}
|
||||||
|
return m.ret1.(bool), m.err
|
||||||
|
}
|
||||||
|
|
||||||
func Test_caHandler_Route(t *testing.T) {
|
func Test_caHandler_Route(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Authority Authority
|
Authority Authority
|
||||||
|
|
|
@ -295,7 +295,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||||
keys, err := h.Authority.GetSSHFederation()
|
keys, err := h.Authority.GetSSHFederation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, NotFound(err))
|
WriteError(w, InternalServerError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
235
api/ssh_test.go
235
api/ssh_test.go
|
@ -12,6 +12,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -19,6 +20,7 @@ import (
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
|
"github.com/smallstep/certificates/templates"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -337,6 +339,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
||||||
statusCode int
|
statusCode int
|
||||||
}{
|
}{
|
||||||
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK},
|
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK},
|
||||||
|
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
|
||||||
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK},
|
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK},
|
||||||
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK},
|
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK},
|
||||||
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
|
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
|
||||||
|
@ -350,25 +353,249 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
}).(*caHandler)
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/keys", http.NoBody)
|
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHRoots(logging.NewResponseLogger(w), req)
|
h.SSHRoots(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
t.Errorf("caHandler.SSHKeys StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(res.Body)
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("caHandler.SSHKeys unexpected error = %v", err)
|
t.Errorf("caHandler.SSHRoots unexpected error = %v", err)
|
||||||
}
|
}
|
||||||
if tt.statusCode < http.StatusBadRequest {
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||||
t.Errorf("caHandler.SSHKeys Body = %s, wants %s", body, tt.body)
|
t.Errorf("caHandler.SSHRoots Body = %s, wants %s", body, tt.body)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_caHandler_SSHFederation(t *testing.T) {
|
||||||
|
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
|
||||||
|
host, err := ssh.NewPublicKey(sshHostKey.Public())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keys *authority.SSHKeys
|
||||||
|
keysErr error
|
||||||
|
body []byte
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK},
|
||||||
|
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
|
||||||
|
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK},
|
||||||
|
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK},
|
||||||
|
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
|
||||||
|
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(&mockAuthority{
|
||||||
|
getSSHFederation: func() (*authority.SSHKeys, error) {
|
||||||
|
return tt.keys, tt.keysErr
|
||||||
|
},
|
||||||
|
}).(*caHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SSHFederation(logging.NewResponseLogger(w), req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("caHandler.SSHFederation unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||||
|
t.Errorf("caHandler.SSHFederation Body = %s, wants %s", body, tt.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_caHandler_SSHConfig(t *testing.T) {
|
||||||
|
userOutput := []templates.Output{
|
||||||
|
{"config.tpl", templates.File, "#", "ssh/config", []byte("UserKnownHostsFile /home/user/.step/config/ssh/known_hosts")},
|
||||||
|
{"known_host.tpl", templates.File, "#", "ssh/known_host", []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
||||||
|
}
|
||||||
|
hostOutput := []templates.Output{
|
||||||
|
{"sshd_config.tpl", templates.Snippet, "#", "/etc/ssh/sshd_config", []byte("TrustedUserCAKeys /etc/ssh/ca.pub")},
|
||||||
|
{"ca.tpl", templates.File, "#", "/etc/ssh/ca.pub", []byte("ecdsa-sha2-nistp256 AAAA...=")},
|
||||||
|
}
|
||||||
|
userJSON, err := json.Marshal(userOutput)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
hostJSON, err := json.Marshal(hostOutput)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req string
|
||||||
|
output []templates.Output
|
||||||
|
err error
|
||||||
|
body []byte
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"user", `{"type":"user"}`, userOutput, nil, []byte(fmt.Sprintf(`{"userTemplates":%s}`, userJSON)), http.StatusOK},
|
||||||
|
{"host", `{"type":"host"}`, hostOutput, nil, []byte(fmt.Sprintf(`{"hostTemplates":%s}`, hostJSON)), http.StatusOK},
|
||||||
|
{"noType", `{}`, userOutput, nil, []byte(fmt.Sprintf(`{"userTemplates":%s}`, userJSON)), http.StatusOK},
|
||||||
|
{"badType", `{"type":"bad"}`, userOutput, nil, nil, http.StatusBadRequest},
|
||||||
|
{"badData", `{"type":"user","data":{"bad"}}`, userOutput, nil, nil, http.StatusBadRequest},
|
||||||
|
{"error", `{"type": "user"}`, nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(&mockAuthority{
|
||||||
|
getSSHConfig: func(typ string, data map[string]string) ([]templates.Output, error) {
|
||||||
|
return tt.output, tt.err
|
||||||
|
},
|
||||||
|
}).(*caHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SSHConfig(logging.NewResponseLogger(w), req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("caHandler.SSHConfig unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||||
|
t.Errorf("caHandler.SSHConfig Body = %s, wants %s", body, tt.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req string
|
||||||
|
exists bool
|
||||||
|
err error
|
||||||
|
body []byte
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"true", `{"type":"host","principal":"foo.example.com"}`, true, nil, []byte(`{"exists":true}`), http.StatusOK},
|
||||||
|
{"false", `{"type":"host","principal":"bar.example.com"}`, false, nil, []byte(`{"exists":false}`), http.StatusOK},
|
||||||
|
{"badType", `{"type":"user","principal":"bar.example.com"}`, false, nil, nil, http.StatusBadRequest},
|
||||||
|
{"badPrincipal", `{"type":"host","principal":""}`, false, nil, nil, http.StatusBadRequest},
|
||||||
|
{"badRequest", `{"foo"}`, false, nil, nil, http.StatusBadRequest},
|
||||||
|
{"error", `{"type":"host","principal":"foo.example.com"}`, false, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(&mockAuthority{
|
||||||
|
checkSSHHost: func(_ string) (bool, error) {
|
||||||
|
return tt.exists, tt.err
|
||||||
|
},
|
||||||
|
}).(*caHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||||
|
t.Errorf("caHandler.SSHCheckHost Body = %s, wants %s", body, tt.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHPublicKey_MarshalJSON(t *testing.T) {
|
||||||
|
key, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
keyB64 := base64.StdEncoding.EncodeToString(key.Marshal())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
publicKey *SSHPublicKey
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", &SSHPublicKey{PublicKey: key}, []byte(`"` + keyB64 + `"`), false},
|
||||||
|
{"null", nil, []byte("null"), false},
|
||||||
|
{"null", &SSHPublicKey{PublicKey: nil}, []byte("null"), false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.publicKey.MarshalJSON()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SSHPublicKey.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SSHPublicKey.MarshalJSON() = %s, want %s", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHPublicKey_UnmarshalJSON(t *testing.T) {
|
||||||
|
key, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
keyB64 := base64.StdEncoding.EncodeToString(key.Marshal())
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *SSHPublicKey
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{[]byte(`"` + keyB64 + `"`)}, &SSHPublicKey{PublicKey: key}, false},
|
||||||
|
{"empty", args{[]byte(`""`)}, &SSHPublicKey{}, false},
|
||||||
|
{"null", args{[]byte(`null`)}, &SSHPublicKey{}, false},
|
||||||
|
{"noString", args{[]byte("123")}, &SSHPublicKey{}, true},
|
||||||
|
{"badB64", args{[]byte(`"bad"`)}, &SSHPublicKey{}, true},
|
||||||
|
{"badKey", args{[]byte(`"Zm9vYmFyCg=="`)}, &SSHPublicKey{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &SSHPublicKey{}
|
||||||
|
if err := p.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SSHPublicKey.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(p, tt.want) {
|
||||||
|
t.Errorf("SSHPublicKey.UnmarshalJSON() = %v, want %v", p, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue