diff --git a/api/ssh_test.go b/api/ssh_test.go index cc615ee7..b5ff7002 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -5,6 +5,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/x509" "encoding/base64" "encoding/json" "fmt" @@ -20,6 +21,7 @@ import ( "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/sshutil" "github.com/smallstep/certificates/templates" "golang.org/x/crypto/ssh" ) @@ -197,6 +199,10 @@ func TestSSHCertificate_UnmarshalJSON(t *testing.T) { } func TestSignSSHRequest_Validate(t *testing.T) { + csr := parseCertificateRequest(csrPEM) + badCSR := parseCertificateRequest(csrPEM) + badCSR.SignatureAlgorithm = x509.SHA1WithRSA + type fields struct { PublicKey []byte OTT string @@ -205,19 +211,24 @@ func TestSignSSHRequest_Validate(t *testing.T) { ValidAfter TimeDuration ValidBefore TimeDuration AddUserPublicKey []byte + KeyID string + IdentityCSR CertificateRequest } tests := []struct { name string fields fields wantErr bool }{ - {"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, - {"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, - {"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, - {"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, - {"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, - {"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, - {"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + {"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false}, + {"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false}, + {"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false}, + {"ok-keyID", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{}}, false}, + {"ok-identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: csr}}, false}, + {"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true}, + {"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true}, + {"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true}, + {"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true}, + {"identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: badCSR}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -229,6 +240,8 @@ func TestSignSSHRequest_Validate(t *testing.T) { ValidAfter: tt.fields.ValidAfter, ValidBefore: tt.fields.ValidBefore, AddUserPublicKey: tt.fields.AddUserPublicKey, + KeyID: tt.fields.KeyID, + IdentityCSR: tt.fields.IdentityCSR, } if err := s.Validate(); (err != nil) != tt.wantErr { t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) @@ -262,28 +275,42 @@ func Test_caHandler_SSHSign(t *testing.T) { AddUserPublicKey: user.Key.Marshal(), }) assert.FatalError(t, err) + userIdentityReq, err := json.Marshal(SSHSignRequest{ + PublicKey: user.Key.Marshal(), + OTT: "ott", + IdentityCSR: CertificateRequest{parseCertificateRequest(csrPEM)}, + }) + assert.FatalError(t, err) + identityCerts := []*x509.Certificate{ + parseCertificate(certPEM), + } + identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`) tests := []struct { - name string - req []byte - authErr error - signCert *ssh.Certificate - signErr error - addUserCert *ssh.Certificate - addUserErr error - body []byte - statusCode int + name string + req []byte + authErr error + signCert *ssh.Certificate + signErr error + addUserCert *ssh.Certificate + addUserErr error + tlsSignCerts []*x509.Certificate + tlsSignErr error + body []byte + statusCode int }{ - {"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated}, - {"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated}, - {"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated}, - {"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, - {"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, - {"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, - {"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, - {"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized}, - {"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden}, - {"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden}, + {"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated}, + {"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated}, + {"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated}, + {"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":"%s","identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated}, + {"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized}, + {"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden}, + {"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden}, + {"fail-user-identity", userIdentityReq, nil, user, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -297,6 +324,9 @@ func Test_caHandler_SSHSign(t *testing.T) { signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { return tt.addUserCert, tt.addUserErr }, + sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + return tt.tlsSignCerts, tt.tlsSignErr + }, }).(*caHandler) req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) @@ -537,6 +567,57 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } } +func Test_caHandler_SSHGetHosts(t *testing.T) { + hosts := []sshutil.Host{ + {HostID: "1", HostGroups: []sshutil.HostGroup{{ID: "1", Name: "group 1"}}, Hostname: "host1"}, + {HostID: "2", HostGroups: []sshutil.HostGroup{{ID: "1", Name: "group 1"}, {ID: "2", Name: "group 2"}}, Hostname: "host2"}, + } + hostsJSON, err := json.Marshal(hosts) + assert.FatalError(t, err) + + tests := []struct { + name string + hosts []sshutil.Host + err error + body []byte + statusCode int + }{ + {"ok", hosts, nil, []byte(fmt.Sprintf(`{"hosts":%s}`, hostsJSON)), http.StatusOK}, + {"empty (array)", []sshutil.Host{}, nil, []byte(`{"hosts":[]}`), http.StatusOK}, + {"empty (nil)", nil, nil, []byte(`{"hosts":null}`), http.StatusOK}, + {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + getSSHHosts: func(*x509.Certificate) ([]sshutil.Host, error) { + return tt.hosts, tt.err + }, + }).(*caHandler) + + req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody) + w := httptest.NewRecorder() + h.SSHGetHosts(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), tt.body) { + t.Errorf("caHandler.SSHGetHosts Body = %s, wants %s", body, tt.body) + } + } + }) + } +} + func Test_caHandler_SSHBastion(t *testing.T) { bastion := &authority.Bastion{ Hostname: "bastion.local",