From ca74bb1de52dd993c06c39dc94a78ea7003ec213 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 5 Aug 2019 16:06:05 -0700 Subject: [PATCH] Add ssh api tests. --- api/api_test.go | 11 +- api/ssh.go | 16 +-- api/ssh_test.go | 327 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 344 insertions(+), 10 deletions(-) create mode 100644 api/ssh_test.go diff --git a/api/api_test.go b/api/api_test.go index 7a3c843d..4fb980e2 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -23,14 +23,13 @@ import ( "testing" "time" - "golang.org/x/crypto/ssh" - "github.com/go-chi/chi" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/jose" + "golang.org/x/crypto/ssh" ) const ( @@ -498,6 +497,7 @@ type mockAuthority struct { root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) singSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + singSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) @@ -547,6 +547,13 @@ func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, return m.ret1.(*ssh.Certificate), m.err } +func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.singSSHAddUser != nil { + return m.singSSHAddUser(key, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) { if m.renew != nil { return m.renew(cert) diff --git a/api/ssh.go b/api/ssh.go index 92deff5e..a847c9a0 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -30,13 +30,13 @@ type SignSSHRequest struct { // SignSSHResponse is the response object that returns the SSH certificate. type SignSSHResponse struct { - Certificate SSHCertificate `json:"crt"` - AddUserCertificate SSHCertificate `json:"addUserCrt"` + Certificate SSHCertificate `json:"crt"` + AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"` } // SSHCertificate represents the response SSH certificate. type SSHCertificate struct { - *ssh.Certificate + *ssh.Certificate `json:"omitempty"` } // MarshalJSON implements the json.Marshaler interface. The certificate is @@ -102,7 +102,7 @@ func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) { logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + WriteError(w, BadRequest(err)) return } @@ -141,19 +141,19 @@ func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) { return } - var addUserCert *ssh.Certificate + var addUserCertificate *SSHCertificate if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { - addUserCert, err = h.Authority.SignSSHAddUser(addUserPublicKey, cert) + addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert) if err != nil { WriteError(w, Forbidden(err)) return } + addUserCertificate = &SSHCertificate{addUserCert} } w.WriteHeader(http.StatusCreated) - // logCertificate(w, cert) JSON(w, &SignSSHResponse{ Certificate: SSHCertificate{cert}, - AddUserCertificate: SSHCertificate{addUserCert}, + AddUserCertificate: addUserCertificate, }) } diff --git a/api/ssh_test.go b/api/ssh_test.go new file mode 100644 index 00000000..df3edf31 --- /dev/null +++ b/api/ssh_test.go @@ -0,0 +1,327 @@ +package api + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/logging" + "golang.org/x/crypto/ssh" +) + +var ( + sshSignerKey = mustKey() + sshUserKey = mustKey() + sshHostKey = mustKey() +) + +func mustKey() *ecdsa.PrivateKey { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + return priv +} + +func signSSHCertificate(cert *ssh.Certificate) error { + signerKey, err := ssh.NewPublicKey(sshSignerKey.Public()) + if err != nil { + return err + } + signer, err := ssh.NewSignerFromSigner(sshSignerKey) + if err != nil { + return err + } + cert.SignatureKey = signerKey + data := cert.Marshal() + data = data[:len(data)-4] + sig, err := signer.Sign(rand.Reader, data) + if err != nil { + return err + } + cert.Signature = sig + return nil +} + +func getSignedUserCertificate() (*ssh.Certificate, error) { + key, err := ssh.NewPublicKey(sshUserKey.Public()) + if err != nil { + return nil, err + } + t := time.Now() + cert := &ssh.Certificate{ + Nonce: []byte("1234567890"), + Key: key, + Serial: 1234567890, + CertType: ssh.UserCert, + KeyId: "user@localhost", + ValidPrincipals: []string{"user"}, + ValidAfter: uint64(t.Unix()), + ValidBefore: uint64(t.Add(time.Hour).Unix()), + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + }, + }, + Reserved: []byte{}, + } + if err := signSSHCertificate(cert); err != nil { + return nil, err + } + return cert, nil +} + +func getSignedHostCertificate() (*ssh.Certificate, error) { + key, err := ssh.NewPublicKey(sshHostKey.Public()) + if err != nil { + return nil, err + } + t := time.Now() + cert := &ssh.Certificate{ + Nonce: []byte("1234567890"), + Key: key, + Serial: 1234567890, + CertType: ssh.UserCert, + KeyId: "internal.smallstep.com", + ValidPrincipals: []string{"internal.smallstep.com"}, + ValidAfter: uint64(t.Unix()), + ValidBefore: uint64(t.Add(time.Hour).Unix()), + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + Reserved: []byte{}, + } + if err := signSSHCertificate(cert); err != nil { + return nil, err + } + return cert, nil +} + +func TestSSHCertificate_MarshalJSON(t *testing.T) { + user, err := getSignedUserCertificate() + assert.FatalError(t, err) + host, err := getSignedHostCertificate() + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + type fields struct { + Certificate *ssh.Certificate + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + {"nil", fields{Certificate: nil}, []byte("null"), false}, + {"user", fields{Certificate: user}, []byte(`"` + userB64 + `"`), false}, + {"user", fields{Certificate: host}, []byte(`"` + hostB64 + `"`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := SSHCertificate{ + Certificate: tt.fields.Certificate, + } + got, err := c.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("SSHCertificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SSHCertificate.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSSHCertificate_UnmarshalJSON(t *testing.T) { + user, err := getSignedUserCertificate() + assert.FatalError(t, err) + host, err := getSignedHostCertificate() + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal()) + + type args struct { + data []byte + } + tests := []struct { + name string + args args + want *ssh.Certificate + wantErr bool + }{ + {"null", args{[]byte(`null`)}, nil, false}, + {"empty", args{[]byte(`""`)}, nil, false}, + {"user", args{[]byte(`"` + userB64 + `"`)}, user, false}, + {"host", args{[]byte(`"` + hostB64 + `"`)}, host, false}, + {"bad-string", args{[]byte(userB64)}, nil, true}, + {"bad-base64", args{[]byte(`"this-is-not-base64"`)}, nil, true}, + {"bad-key", args{[]byte(`"bm90LWEta2V5"`)}, nil, true}, + {"bat-cert", args{[]byte(`"` + keyB64 + `"`)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &SSHCertificate{} + if err := c.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("SSHCertificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(tt.want, c.Certificate) { + t.Errorf("SSHCertificate.UnmarshalJSON() got = %v, want %v\n", c.Certificate, tt.want) + } + }) + } +} + +func TestSignSSHRequest_Validate(t *testing.T) { + type fields struct { + PublicKey []byte + OTT string + CertType string + Principals []string + ValidAfter TimeDuration + ValidBefore TimeDuration + AddUserPublicKey []byte + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, + {"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, + {"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false}, + {"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + {"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + {"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + {"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &SignSSHRequest{ + PublicKey: tt.fields.PublicKey, + OTT: tt.fields.OTT, + CertType: tt.fields.CertType, + Principals: tt.fields.Principals, + ValidAfter: tt.fields.ValidAfter, + ValidBefore: tt.fields.ValidBefore, + AddUserPublicKey: tt.fields.AddUserPublicKey, + } + if err := s.Validate(); (err != nil) != tt.wantErr { + t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_caHandler_SignSSH(t *testing.T) { + user, err := getSignedUserCertificate() + assert.FatalError(t, err) + host, err := getSignedHostCertificate() + assert.FatalError(t, err) + + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + userReq, err := json.Marshal(SignSSHRequest{ + PublicKey: user.Key.Marshal(), + OTT: "ott", + }) + assert.FatalError(t, err) + hostReq, err := json.Marshal(SignSSHRequest{ + PublicKey: host.Key.Marshal(), + OTT: "ott", + }) + assert.FatalError(t, err) + userAddReq, err := json.Marshal(SignSSHRequest{ + PublicKey: user.Key.Marshal(), + OTT: "ott", + AddUserPublicKey: user.Key.Marshal(), + }) + assert.FatalError(t, err) + + type fields struct { + Authority Authority + } + type args struct { + w http.ResponseWriter + r *http.Request + } + tests := []struct { + name string + req []byte + authErr error + signCert *ssh.Certificate + signErr error + addUserCert *ssh.Certificate + addUserErr error + body []byte + statusCode int + }{ + {"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated}, + {"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated}, + {"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated}, + {"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized}, + {"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden}, + {"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + return []provisioner.SignOption{}, tt.authErr + }, + singSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + return tt.signCert, tt.signErr + }, + singSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + return tt.addUserCert, tt.addUserErr + }, + }).(*caHandler) + + req := httptest.NewRequest("POST", "http://example.com/sign-ssh", bytes.NewReader(tt.req)) + w := httptest.NewRecorder() + h.SignSSH(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Root unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), tt.body) { + t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.body) + } + } + }) + } +}