Introduce generalized statusCoder errors and loads of ssh unit tests.

* StatusCoder api errors that have friendly user messages.
* Unit tests for SSH sign/renew/rekey/revoke across all provisioners.
This commit is contained in:
max furman 2019-12-20 13:30:05 -08:00
parent 3ce267cdd6
commit c387b21808
75 changed files with 5292 additions and 2201 deletions

View file

@ -16,12 +16,12 @@ import (
"testing"
"time"
"github.com/smallstep/certificates/errs"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/crypto/x509util"
"golang.org/x/crypto/ssh"
)
@ -154,18 +154,17 @@ func equalJSON(t *testing.T, a interface{}, b interface{}) bool {
func TestClient_Version(t *testing.T) {
ok := &api.VersionResponse{Version: "test"}
internal := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", ok, 200, false},
{"500", internal, 500, true},
{"404", notFound, 404, true},
{"ok", ok, 200, false, nil},
{"500", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
{"404", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -185,7 +184,6 @@ func TestClient_Version(t *testing.T) {
got, err := c.Version()
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr)
return
}
@ -195,9 +193,7 @@ func TestClient_Version(t *testing.T) {
if got != nil {
t.Errorf("Client.Version() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Version() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Version() = %v, want %v", got, tt.response)
@ -209,16 +205,16 @@ func TestClient_Version(t *testing.T) {
func TestClient_Health(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"}
nok := errs.InternalServerError(fmt.Errorf("Internal Server Error"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", ok, 200, false},
{"not ok", nok, 500, true},
{"ok", ok, 200, false, nil},
{"not ok", errs.InternalServerError(errors.New("force")), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -248,9 +244,7 @@ func TestClient_Health(t *testing.T) {
if got != nil {
t.Errorf("Client.Health() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Health() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Health() = %v, want %v", got, tt.response)
@ -264,7 +258,6 @@ func TestClient_Root(t *testing.T) {
ok := &api.RootResponse{
RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
}
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
@ -272,9 +265,10 @@ func TestClient_Root(t *testing.T) {
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false},
{"not found", "invalid", notFound, 404, true},
{"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil},
{"not found", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -307,9 +301,7 @@ func TestClient_Root(t *testing.T) {
if got != nil {
t.Errorf("Client.Root() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Root() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Root() = %v, want %v", got, tt.response)
@ -334,8 +326,6 @@ func TestClient_Sign(t *testing.T) {
NotBefore: api.NewTimeDuration(time.Now()),
NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)),
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
@ -343,11 +333,12 @@ func TestClient_Sign(t *testing.T) {
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", request, ok, 200, false},
{"unauthorized", request, unauthorized, 401, true},
{"empty request", &api.SignRequest{}, badRequest, 403, true},
{"nil request", nil, badRequest, 403, true},
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -364,7 +355,9 @@ func TestClient_Sign(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest)
if err := api.ReadJSON(req.Body, body); err != nil {
api.WriteError(w, badRequest)
e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
@ -390,9 +383,7 @@ func TestClient_Sign(t *testing.T) {
if got != nil {
t.Errorf("Client.Sign() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Sign() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
@ -409,19 +400,17 @@ func TestClient_Revoke(t *testing.T) {
OTT: "the-ott",
ReasonCode: 4,
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
request *api.RevokeRequest
response interface{}
responseCode int
wantErr bool
expectedErr error
}{
{"ok", request, ok, 200, false},
{"unauthorized", request, unauthorized, 401, true},
{"nil request", nil, badRequest, 403, true},
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -438,7 +427,9 @@ func TestClient_Revoke(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest)
if err := api.ReadJSON(req.Body, body); err != nil {
api.WriteError(w, badRequest)
e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
@ -464,9 +455,7 @@ func TestClient_Revoke(t *testing.T) {
if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Revoke() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
@ -485,19 +474,18 @@ func TestClient_Renew(t *testing.T) {
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -527,9 +515,11 @@ func TestClient_Renew(t *testing.T) {
if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Renew() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -589,9 +579,7 @@ func TestClient_Provisioners(t *testing.T) {
if got != nil {
t.Errorf("Client.Provisioners() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Provisioners() error = %v, want %v", err, tt.response)
}
assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
@ -605,7 +593,6 @@ func TestClient_ProvisionerKey(t *testing.T) {
ok := &api.ProvisionerKeyResponse{
Key: "an encrypted key",
}
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
@ -613,9 +600,10 @@ func TestClient_ProvisionerKey(t *testing.T) {
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", "kid", ok, 200, false},
{"fail", "invalid", notFound, 500, true},
{"ok", "kid", ok, 200, false, nil},
{"fail", "invalid", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -648,9 +636,11 @@ func TestClient_ProvisionerKey(t *testing.T) {
if got != nil {
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.ProvisionerKey() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
@ -666,19 +656,17 @@ func TestClient_Roots(t *testing.T) {
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"bad-request", errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -708,9 +696,10 @@ func TestClient_Roots(t *testing.T) {
if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Roots() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
@ -726,19 +715,16 @@ func TestClient_Federation(t *testing.T) {
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := errs.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized(errors.New("force")), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -768,9 +754,10 @@ func TestClient_Federation(t *testing.T) {
if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
@ -790,16 +777,16 @@ func TestClient_SSHRoots(t *testing.T) {
HostKeys: []api.SSHPublicKey{{PublicKey: key}},
UserKeys: []api.SSHPublicKey{{PublicKey: key}},
}
notFound := errs.NotFound(fmt.Errorf("Not Found"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false},
{"not found", notFound, 404, true},
{"ok", ok, 200, false, nil},
{"not found", errs.NotFound(errors.New("force")), 404, true, errors.New(errs.NotFoundDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -829,9 +816,10 @@ func TestClient_SSHRoots(t *testing.T) {
if got != nil {
t.Errorf("Client.SSHKeys() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.SSHKeys() error = %v, want %v", err, tt.response)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
@ -948,7 +936,6 @@ func TestClient_SSHBastion(t *testing.T) {
Hostname: "bastion.local",
},
}
badRequest := errs.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
@ -956,11 +943,11 @@ func TestClient_SSHBastion(t *testing.T) {
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false},
{"bad response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true},
{"empty request", &api.SSHBastionRequest{}, badRequest, 403, true},
{"nil request", nil, badRequest, 403, true},
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest(errors.New("force")), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
@ -990,8 +977,11 @@ func TestClient_SSHBastion(t *testing.T) {
if got != nil {
t.Errorf("Client.SSHBastion() = %v, want nil", got)
}
if tt.responseCode != 200 && !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.SSHBastion() error = %v, want %v", err, tt.response)
if tt.responseCode != 200 {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
}
default:
if !reflect.DeepEqual(got, tt.response) {