diff --git a/api/api.go b/api/api.go index bc61ad3e..9e6ee301 100644 --- a/api/api.go +++ b/api/api.go @@ -487,7 +487,7 @@ type MockAuthority struct { rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) MockLoadProvisionerByName func(name string) (provisioner.Interface, error) - getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) + MockGetProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) revoke func(context.Context, *authority.RevokeOptions) error getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) @@ -567,8 +567,8 @@ func (m *MockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([ } func (m *MockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { - if m.getProvisioners != nil { - return m.getProvisioners(nextCursor, limit) + if m.MockGetProvisioners != nil { + return m.MockGetProvisioners(nextCursor, limit) } return m.ret1.(provisioner.List), m.ret2.(string), m.err } diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go index 114629fc..15c581f4 100644 --- a/authority/admin/api/acme_test.go +++ b/authority/admin/api/acme_test.go @@ -1154,21 +1154,23 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) { assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) - } else { - body, err := io.ReadAll(res.Body) - res.Body.Close() - assert.FatalError(t, err) - - response := GetExternalAccountKeysResponse{} - assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) - - assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) - - opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.EABKey{}, timestamppb.Timestamp{})} - if !cmp.Equal(tc.resp, response, opts...) { - t.Errorf("h.GetExternalAccountKeys diff =\n%s", cmp.Diff(tc.resp, response, opts...)) - } + return } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + response := GetExternalAccountKeysResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.EABKey{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.resp, response, opts...) { + t.Errorf("h.GetExternalAccountKeys diff =\n%s", cmp.Diff(tc.resp, response, opts...)) + } + }) } } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index fd1a02d5..d111f1e6 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -54,7 +54,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, - "error parsing cursor & limit query params")) + "error parsing cursor and limit from query params")) return } diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go new file mode 100644 index 00000000..68a54fe8 --- /dev/null +++ b/authority/admin/api/provisioner_test.go @@ -0,0 +1,1105 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestHandler_GetProvisioner(t *testing.T) { + type test struct { + ctx context.Context + auth api.LinkedAuthority + db admin.DB + req *http.Request + statusCode int + err *admin.Error + prov *linkedca.Provisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByID": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo?id=provID", nil) + chiCtx := chi.NewRouteContext() + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByID: func(id string) (provisioner.Interface, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provID: force", + }, + } + }, + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provName: force", + }, + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.ACME{ + ID: "acmeID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "acmeID", id) + return nil, admin.NewErrorISE("error loading provisioner provName: force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provName: force", + }, + } + }, + "ok": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.ACME{ + ID: "acmeID", + Name: "provName", + }, nil + }, + } + prov := &linkedca.Provisioner{ + Id: "acmeID", + Type: linkedca.Provisioner_ACME, + Name: "provName", // TODO(hs): other fields too? + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "acmeID", id) + return prov, nil + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + db: db, + statusCode: 200, + err: nil, + prov: prov, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + db: tc.db, + } + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetProvisioner(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + prov := &linkedca.Provisioner{} + err := readProtoJSON(res.Body, prov) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Provisioner{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.prov, prov, opts...) { + t.Errorf("h.GetProvisioner diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) + } + }) + } +} + +func TestHandler_GetProvisioners(t *testing.T) { + type test struct { + ctx context.Context + auth api.LinkedAuthority + req *http.Request + statusCode int + err *admin.Error + resp GetProvisionersResponse + } + var tests = map[string]func(t *testing.T) test{ + "fail/parse-cursor": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo?limit=X", nil) + return test{ + ctx: context.Background(), + statusCode: 400, + req: req, + err: &admin.Error{ + Status: 400, + Type: admin.ErrorBadRequestType.String(), + Detail: "bad request", + Message: "error parsing cursor and limit from query params: limit 'X' is not an integer: strconv.Atoi: parsing \"X\": invalid syntax", + }, + } + }, + "fail/auth.GetProvisioners": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + auth := &api.MockAuthority{ + MockGetProvisioners: func(cursor string, limit int) (provisioner.List, string, error) { + assert.Equals(t, "", cursor) + assert.Equals(t, 0, limit) + return nil, "", errors.New("force") + }, + } + return test{ + ctx: context.Background(), + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: "", + Status: 500, + Detail: "", + Message: "The certificate authority encountered an Internal Server Error. Please see the certificate authority logs for more info.", + }, + } + }, + "ok": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + provisioners := provisioner.List{ + &provisioner.OIDC{ + Type: "OIDC", + Name: "oidcProv", + }, + &provisioner.ACME{ + Type: "ACME", + Name: "provName", + ForceCN: false, + RequireEAB: false, + }, + } + auth := &api.MockAuthority{ + MockGetProvisioners: func(cursor string, limit int) (provisioner.List, string, error) { + assert.Equals(t, "", cursor) + assert.Equals(t, 0, limit) + return provisioners, "nextCursorValue", nil + }, + } + return test{ + ctx: context.Background(), + req: req, + auth: auth, + statusCode: 200, + err: nil, + resp: GetProvisionersResponse{ + Provisioners: provisioners, + NextCursor: "nextCursorValue", + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetProvisioners(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + response := GetProvisionersResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(provisioner.ACME{}, provisioner.OIDC{})} + if !cmp.Equal(tc.resp, response, opts...) { + t.Errorf("h.GetProvisioners diff =\n%s", cmp.Diff(tc.resp, response, opts...)) + } + }) + } +} + +func TestHandler_CreateProvisioner(t *testing.T) { + type test struct { + ctx context.Context + auth api.LinkedAuthority + body []byte + statusCode int + err *admin.Error + prov *linkedca.Provisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/readProtoJSON": func(t *testing.T) test { + body := []byte("{!?}") + return test{ + ctx: context.Background(), + body: body, + statusCode: 500, + err: &admin.Error{ // TODO(hs): this probably needs a better error + Type: "", + Status: 500, + Detail: "", + Message: "", + }, + } + }, + // TODO(hs): ValidateClaims can't be mocked atm + // "fail/authority.ValidateClaims": func(t *testing.T) test { + // return test{} + // }, + "fail/auth.StoreProvisioner": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockStoreProvisioner: func(ctx context.Context, prov *linkedca.Provisioner) error { + assert.Equals(t, "provID", prov.Id) + return errors.New("force") + }, + } + return test{ + ctx: context.Background(), + body: body, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error storing provisioner provName: force", + }, + } + }, + "ok": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockStoreProvisioner: func(ctx context.Context, prov *linkedca.Provisioner) error { + assert.Equals(t, "provID", prov.Id) + return nil + }, + } + return test{ + ctx: context.Background(), + body: body, + auth: auth, + statusCode: 201, + err: nil, + prov: prov, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.CreateProvisioner(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + prov := &linkedca.Provisioner{} + err := readProtoJSON(res.Body, prov) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Provisioner{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.prov, prov, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) + } + }) + } +} + +func TestHandler_DeleteProvisioner(t *testing.T) { + type test struct { + ctx context.Context + auth api.LinkedAuthority + req *http.Request + statusCode int + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByID": func(t *testing.T) test { + req := httptest.NewRequest("DELETE", "/foo?id=provID", nil) + chiCtx := chi.NewRouteContext() + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByID: func(id string) (provisioner.Interface, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provID: force", + }, + } + }, + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + req := httptest.NewRequest("DELETE", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provName: force", + }, + } + }, + "fail/auth.RemoveProvisioner": func(t *testing.T) test { + req := httptest.NewRequest("DELETE", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + Type: "OIDC", + }, nil + }, + MockRemoveProvisioner: func(ctx context.Context, id string) error { + assert.Equals(t, "provID", id) + return errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error removing provisioner provName: force", + }, + } + }, + "ok": func(t *testing.T) test { + req := httptest.NewRequest("DELETE", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + Type: "OIDC", + }, nil + }, + MockRemoveProvisioner: func(ctx context.Context, id string) error { + assert.Equals(t, "provID", id) + return nil + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 200, + err: nil, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.DeleteProvisioner(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + response := DeleteResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equals(t, "ok", response.Status) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + }) + } +} + +func TestHandler_UpdateProvisioner(t *testing.T) { + type test struct { + ctx context.Context + auth api.LinkedAuthority + body []byte + db admin.DB + statusCode int + err *admin.Error + prov *linkedca.Provisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/readProtoJSON": func(t *testing.T) test { + body := []byte("{!?}") + return test{ + ctx: context.Background(), + body: body, + statusCode: 500, + err: &admin.Error{ // TODO(hs): this probably needs a better error + Type: "", + Status: 500, + Detail: "", + Message: "", + }, + } + }, + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + // return &provisioner.OIDC{ + // ID: "provID", + // Name: "provName", + // }, nil + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner from cached configuration 'provName': force", + }, + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner from db 'provID': force", + }, + } + }, + "fail/change-id-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "differentProvID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner ID", + }, + } + }, + "fail/change-type-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_JWK, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner type", + }, + } + }, + "fail/change-authority-id-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "differentAuthorityID", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner authorityID", + }, + } + }, + "fail/change-createdAt-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner createdAt", + }, + } + }, + "fail/change-deletedAt-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + var deletedAt time.Time + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(time.Now()), + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner deletedAt", + }, + } + }, + // TODO(hs): ValidateClaims can't be mocked atm + //"fail/ValidateClaims": func(t *testing.T) test { return test{} }, + "fail/auth.UpdateProvisioner": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + var deletedAt time.Time + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + assert.Equals(t, "provID", nu.Id) + assert.Equals(t, "provName", nu.Name) + return errors.New("force") + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: "", // TODO(hs): this error can be improved + Status: 500, + Detail: "", + Message: "", + }, + } + }, + "ok": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + var deletedAt time.Time + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_OIDC{ + OIDC: &linkedca.OIDCProvisioner{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + }, + }, + }, + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + assert.Equals(t, "provID", nu.Id) + assert.Equals(t, "provName", nu.Name) + return nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 200, + prov: prov, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + db: tc.db, + } + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.UpdateProvisioner(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + prov := &linkedca.Provisioner{} + err := readProtoJSON(res.Body, prov) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{ + cmpopts.IgnoreUnexported( + linkedca.Provisioner{}, linkedca.ProvisionerDetails{}, linkedca.ProvisionerDetails_OIDC{}, + linkedca.OIDCProvisioner{}, timestamppb.Timestamp{}, + ), + } + if !cmp.Equal(tc.prov, prov, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) + } + }) + } +}