diff --git a/api/api_test.go b/api/api_test.go index f1d90db7..e42de188 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -516,6 +516,7 @@ type mockAuthority struct { getSSHRoots func() (*authority.SSHKeys, error) getSSHFederation func() (*authority.SSHKeys, error) getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error) + checkSSHHost func(principal string) (bool, error) } // TODO: remove once Authorize is deprecated. @@ -642,6 +643,13 @@ func (m *mockAuthority) GetSSHConfig(typ string, data map[string]string) ([]temp return m.ret1.([]templates.Output), m.err } +func (m *mockAuthority) CheckSSHHost(principal string) (bool, error) { + if m.checkSSHHost != nil { + return m.checkSSHHost(principal) + } + return m.ret1.(bool), m.err +} + func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority diff --git a/api/ssh.go b/api/ssh.go index d5ad735a..e3101b8b 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -295,7 +295,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHFederation() if err != nil { - WriteError(w, NotFound(err)) + WriteError(w, InternalServerError(err)) return } diff --git a/api/ssh_test.go b/api/ssh_test.go index 3f5fcdbb..ed107b6c 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "testing" "time" @@ -19,6 +20,7 @@ import ( "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/templates" "golang.org/x/crypto/ssh" ) @@ -337,6 +339,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { statusCode int }{ {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, + {"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK}, {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, @@ -350,25 +353,249 @@ func Test_caHandler_SSHRoots(t *testing.T) { }, }).(*caHandler) - req := httptest.NewRequest("GET", "http://example.com/ssh/keys", http.NoBody) + req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody) w := httptest.NewRecorder() h.SSHRoots(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.SSHKeys StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { - t.Errorf("caHandler.SSHKeys unexpected error = %v", err) + t.Errorf("caHandler.SSHRoots unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), tt.body) { - t.Errorf("caHandler.SSHKeys Body = %s, wants %s", body, tt.body) + t.Errorf("caHandler.SSHRoots Body = %s, wants %s", body, tt.body) } } }) } } + +func Test_caHandler_SSHFederation(t *testing.T) { + user, err := ssh.NewPublicKey(sshUserKey.Public()) + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + + host, err := ssh.NewPublicKey(sshHostKey.Public()) + assert.FatalError(t, err) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + tests := []struct { + name string + keys *authority.SSHKeys + keysErr error + body []byte + statusCode int + }{ + {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, + {"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK}, + {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, + {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, + {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, + {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + getSSHFederation: func() (*authority.SSHKeys, error) { + return tt.keys, tt.keysErr + }, + }).(*caHandler) + + req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody) + w := httptest.NewRecorder() + h.SSHFederation(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.SSHFederation unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), tt.body) { + t.Errorf("caHandler.SSHFederation Body = %s, wants %s", body, tt.body) + } + } + }) + } +} + +func Test_caHandler_SSHConfig(t *testing.T) { + userOutput := []templates.Output{ + {"config.tpl", templates.File, "#", "ssh/config", []byte("UserKnownHostsFile /home/user/.step/config/ssh/known_hosts")}, + {"known_host.tpl", templates.File, "#", "ssh/known_host", []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")}, + } + hostOutput := []templates.Output{ + {"sshd_config.tpl", templates.Snippet, "#", "/etc/ssh/sshd_config", []byte("TrustedUserCAKeys /etc/ssh/ca.pub")}, + {"ca.tpl", templates.File, "#", "/etc/ssh/ca.pub", []byte("ecdsa-sha2-nistp256 AAAA...=")}, + } + userJSON, err := json.Marshal(userOutput) + assert.FatalError(t, err) + hostJSON, err := json.Marshal(hostOutput) + assert.FatalError(t, err) + + tests := []struct { + name string + req string + output []templates.Output + err error + body []byte + statusCode int + }{ + {"user", `{"type":"user"}`, userOutput, nil, []byte(fmt.Sprintf(`{"userTemplates":%s}`, userJSON)), http.StatusOK}, + {"host", `{"type":"host"}`, hostOutput, nil, []byte(fmt.Sprintf(`{"hostTemplates":%s}`, hostJSON)), http.StatusOK}, + {"noType", `{}`, userOutput, nil, []byte(fmt.Sprintf(`{"userTemplates":%s}`, userJSON)), http.StatusOK}, + {"badType", `{"type":"bad"}`, userOutput, nil, nil, http.StatusBadRequest}, + {"badData", `{"type":"user","data":{"bad"}}`, userOutput, nil, nil, http.StatusBadRequest}, + {"error", `{"type": "user"}`, nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + getSSHConfig: func(typ string, data map[string]string) ([]templates.Output, error) { + return tt.output, tt.err + }, + }).(*caHandler) + + req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req)) + w := httptest.NewRecorder() + h.SSHConfig(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.SSHConfig unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), tt.body) { + t.Errorf("caHandler.SSHConfig Body = %s, wants %s", body, tt.body) + } + } + }) + } +} + +func Test_caHandler_SSHCheckHost(t *testing.T) { + tests := []struct { + name string + req string + exists bool + err error + body []byte + statusCode int + }{ + {"true", `{"type":"host","principal":"foo.example.com"}`, true, nil, []byte(`{"exists":true}`), http.StatusOK}, + {"false", `{"type":"host","principal":"bar.example.com"}`, false, nil, []byte(`{"exists":false}`), http.StatusOK}, + {"badType", `{"type":"user","principal":"bar.example.com"}`, false, nil, nil, http.StatusBadRequest}, + {"badPrincipal", `{"type":"host","principal":""}`, false, nil, nil, http.StatusBadRequest}, + {"badRequest", `{"foo"}`, false, nil, nil, http.StatusBadRequest}, + {"error", `{"type":"host","principal":"foo.example.com"}`, false, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + checkSSHHost: func(_ string) (bool, error) { + return tt.exists, tt.err + }, + }).(*caHandler) + + req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req)) + w := httptest.NewRecorder() + h.SSHCheckHost(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), tt.body) { + t.Errorf("caHandler.SSHCheckHost Body = %s, wants %s", body, tt.body) + } + } + }) + } +} + +func TestSSHPublicKey_MarshalJSON(t *testing.T) { + key, err := ssh.NewPublicKey(sshUserKey.Public()) + assert.FatalError(t, err) + keyB64 := base64.StdEncoding.EncodeToString(key.Marshal()) + + tests := []struct { + name string + publicKey *SSHPublicKey + want []byte + wantErr bool + }{ + {"ok", &SSHPublicKey{PublicKey: key}, []byte(`"` + keyB64 + `"`), false}, + {"null", nil, []byte("null"), false}, + {"null", &SSHPublicKey{PublicKey: nil}, []byte("null"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.publicKey.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("SSHPublicKey.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SSHPublicKey.MarshalJSON() = %s, want %s", got, tt.want) + } + }) + } +} + +func TestSSHPublicKey_UnmarshalJSON(t *testing.T) { + key, err := ssh.NewPublicKey(sshUserKey.Public()) + assert.FatalError(t, err) + keyB64 := base64.StdEncoding.EncodeToString(key.Marshal()) + + type args struct { + data []byte + } + tests := []struct { + name string + args args + want *SSHPublicKey + wantErr bool + }{ + {"ok", args{[]byte(`"` + keyB64 + `"`)}, &SSHPublicKey{PublicKey: key}, false}, + {"empty", args{[]byte(`""`)}, &SSHPublicKey{}, false}, + {"null", args{[]byte(`null`)}, &SSHPublicKey{}, false}, + {"noString", args{[]byte("123")}, &SSHPublicKey{}, true}, + {"badB64", args{[]byte(`"bad"`)}, &SSHPublicKey{}, true}, + {"badKey", args{[]byte(`"Zm9vYmFyCg=="`)}, &SSHPublicKey{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &SSHPublicKey{} + if err := p.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("SSHPublicKey.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(p, tt.want) { + t.Errorf("SSHPublicKey.UnmarshalJSON() = %v, want %v", p, tt.want) + } + }) + } +}