forked from TrueCloudLab/certificates
Add missing unit tests for ssh.
This commit is contained in:
parent
a049e1f7e7
commit
15a222d354
1 changed files with 107 additions and 26 deletions
133
api/ssh_test.go
133
api/ssh_test.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -20,6 +21,7 @@ import (
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
|
"github.com/smallstep/certificates/sshutil"
|
||||||
"github.com/smallstep/certificates/templates"
|
"github.com/smallstep/certificates/templates"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
@ -197,6 +199,10 @@ func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSignSSHRequest_Validate(t *testing.T) {
|
func TestSignSSHRequest_Validate(t *testing.T) {
|
||||||
|
csr := parseCertificateRequest(csrPEM)
|
||||||
|
badCSR := parseCertificateRequest(csrPEM)
|
||||||
|
badCSR.SignatureAlgorithm = x509.SHA1WithRSA
|
||||||
|
|
||||||
type fields struct {
|
type fields struct {
|
||||||
PublicKey []byte
|
PublicKey []byte
|
||||||
OTT string
|
OTT string
|
||||||
|
@ -205,19 +211,24 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
||||||
ValidAfter TimeDuration
|
ValidAfter TimeDuration
|
||||||
ValidBefore TimeDuration
|
ValidBefore TimeDuration
|
||||||
AddUserPublicKey []byte
|
AddUserPublicKey []byte
|
||||||
|
KeyID string
|
||||||
|
IdentityCSR CertificateRequest
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
fields fields
|
fields fields
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
|
||||||
{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
|
||||||
{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
|
||||||
{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
{"ok-keyID", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{}}, false},
|
||||||
{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
{"ok-identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: csr}}, false},
|
||||||
{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
|
||||||
{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
|
||||||
|
{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
|
||||||
|
{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
|
||||||
|
{"identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: badCSR}}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -229,6 +240,8 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
||||||
ValidAfter: tt.fields.ValidAfter,
|
ValidAfter: tt.fields.ValidAfter,
|
||||||
ValidBefore: tt.fields.ValidBefore,
|
ValidBefore: tt.fields.ValidBefore,
|
||||||
AddUserPublicKey: tt.fields.AddUserPublicKey,
|
AddUserPublicKey: tt.fields.AddUserPublicKey,
|
||||||
|
KeyID: tt.fields.KeyID,
|
||||||
|
IdentityCSR: tt.fields.IdentityCSR,
|
||||||
}
|
}
|
||||||
if err := s.Validate(); (err != nil) != tt.wantErr {
|
if err := s.Validate(); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -262,28 +275,42 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
AddUserPublicKey: user.Key.Marshal(),
|
AddUserPublicKey: user.Key.Marshal(),
|
||||||
})
|
})
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
userIdentityReq, err := json.Marshal(SSHSignRequest{
|
||||||
|
PublicKey: user.Key.Marshal(),
|
||||||
|
OTT: "ott",
|
||||||
|
IdentityCSR: CertificateRequest{parseCertificateRequest(csrPEM)},
|
||||||
|
})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
identityCerts := []*x509.Certificate{
|
||||||
|
parseCertificate(certPEM),
|
||||||
|
}
|
||||||
|
identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
req []byte
|
req []byte
|
||||||
authErr error
|
authErr error
|
||||||
signCert *ssh.Certificate
|
signCert *ssh.Certificate
|
||||||
signErr error
|
signErr error
|
||||||
addUserCert *ssh.Certificate
|
addUserCert *ssh.Certificate
|
||||||
addUserErr error
|
addUserErr error
|
||||||
body []byte
|
tlsSignCerts []*x509.Certificate
|
||||||
statusCode int
|
tlsSignErr error
|
||||||
|
body []byte
|
||||||
|
statusCode int
|
||||||
}{
|
}{
|
||||||
{"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated},
|
{"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated},
|
||||||
{"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated},
|
{"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated},
|
||||||
{"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated},
|
{"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated},
|
||||||
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
{"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":"%s","identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated},
|
||||||
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized},
|
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
|
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized},
|
||||||
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
|
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden},
|
||||||
|
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
|
||||||
|
{"fail-user-identity", userIdentityReq, nil, user, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -297,6 +324,9 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
return tt.addUserCert, tt.addUserErr
|
return tt.addUserCert, tt.addUserErr
|
||||||
},
|
},
|
||||||
|
sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||||
|
return tt.tlsSignCerts, tt.tlsSignErr
|
||||||
|
},
|
||||||
}).(*caHandler)
|
}).(*caHandler)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
||||||
|
@ -537,6 +567,57 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||||
|
hosts := []sshutil.Host{
|
||||||
|
{HostID: "1", HostGroups: []sshutil.HostGroup{{ID: "1", Name: "group 1"}}, Hostname: "host1"},
|
||||||
|
{HostID: "2", HostGroups: []sshutil.HostGroup{{ID: "1", Name: "group 1"}, {ID: "2", Name: "group 2"}}, Hostname: "host2"},
|
||||||
|
}
|
||||||
|
hostsJSON, err := json.Marshal(hosts)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hosts []sshutil.Host
|
||||||
|
err error
|
||||||
|
body []byte
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"ok", hosts, nil, []byte(fmt.Sprintf(`{"hosts":%s}`, hostsJSON)), http.StatusOK},
|
||||||
|
{"empty (array)", []sshutil.Host{}, nil, []byte(`{"hosts":[]}`), http.StatusOK},
|
||||||
|
{"empty (nil)", nil, nil, []byte(`{"hosts":null}`), http.StatusOK},
|
||||||
|
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(&mockAuthority{
|
||||||
|
getSSHHosts: func(*x509.Certificate) ([]sshutil.Host, error) {
|
||||||
|
return tt.hosts, tt.err
|
||||||
|
},
|
||||||
|
}).(*caHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||||
|
t.Errorf("caHandler.SSHGetHosts Body = %s, wants %s", body, tt.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHBastion(t *testing.T) {
|
func Test_caHandler_SSHBastion(t *testing.T) {
|
||||||
bastion := &authority.Bastion{
|
bastion := &authority.Bastion{
|
||||||
Hostname: "bastion.local",
|
Hostname: "bastion.local",
|
||||||
|
|
Loading…
Reference in a new issue