Add Admin API Middleware tests
This commit is contained in:
parent
d799359917
commit
bd169f505f
2 changed files with 238 additions and 7 deletions
|
@ -161,7 +161,7 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
h.requireEABEnabled(tc.next)(w, req)
|
h.requireEABEnabled(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
|
@ -178,8 +178,6 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// nothing to test when the requireEABEnabled middleware succeeds, currently
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -405,6 +403,13 @@ func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
|
||||||
fields fields
|
fields fields
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
|
{
|
||||||
|
name: "fail/reference-too-long",
|
||||||
|
fields: fields{
|
||||||
|
Reference: strings.Repeat("A", 257),
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "ok/empty-reference",
|
name: "ok/empty-reference",
|
||||||
fields: fields{
|
fields: fields{
|
||||||
|
@ -702,7 +707,7 @@ func TestHandler_CreateExternalAccountKey(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.CreateExternalAccountKey(w, req)
|
h.CreateExternalAccountKey(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
if res.StatusCode >= 400 {
|
if res.StatusCode >= 400 {
|
||||||
|
|
||||||
|
@ -797,7 +802,7 @@ func TestHandler_DeleteExternalAccountKey(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.DeleteExternalAccountKey(w, req)
|
h.DeleteExternalAccountKey(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
if res.StatusCode >= 400 {
|
if res.StatusCode >= 400 {
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
|
@ -1131,7 +1136,7 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetExternalAccountKeys(w, req)
|
h.GetExternalAccountKeys(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
if res.StatusCode >= 400 {
|
if res.StatusCode >= 400 {
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
|
@ -1158,7 +1163,7 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) {
|
||||||
|
|
||||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.EABKey{}, timestamppb.Timestamp{})}
|
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.EABKey{}, timestamppb.Timestamp{})}
|
||||||
if !cmp.Equal(tc.resp, response, opts...) {
|
if !cmp.Equal(tc.resp, response, opts...) {
|
||||||
t.Errorf("h.CreateExternalAccountKey diff =\n%s", cmp.Diff(tc.resp, response, opts...))
|
t.Errorf("h.GetExternalAccountKeys diff =\n%s", cmp.Diff(tc.resp, response, opts...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
226
authority/admin/api/middleware_test.go
Normal file
226
authority/admin/api/middleware_test.go
Normal file
|
@ -0,0 +1,226 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
|
"go.step.sm/linkedca"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandler_requireAPIEnabled(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
ctx context.Context
|
||||||
|
auth api.LinkedAuthority
|
||||||
|
next nextHTTP
|
||||||
|
err *admin.Error
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/auth.IsAdminAPIEnabled": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
auth: &api.MockAuthority{
|
||||||
|
MockIsAdminAPIEnabled: func() bool {
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: &admin.Error{
|
||||||
|
Type: admin.ErrorNotImplementedType.String(),
|
||||||
|
Status: 501,
|
||||||
|
Detail: "not implemented",
|
||||||
|
Message: "administration API not enabled",
|
||||||
|
},
|
||||||
|
statusCode: 501,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
auth := &api.MockAuthority{
|
||||||
|
MockIsAdminAPIEnabled: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
next := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(nil) // mock response with status 200
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
auth: auth,
|
||||||
|
next: next,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, prep := range tests {
|
||||||
|
tc := prep(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := &Handler{
|
||||||
|
auth: tc.auth,
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.requireAPIEnabled(tc.next)(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 {
|
||||||
|
err := admin.Error{}
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||||
|
|
||||||
|
assert.Equals(t, tc.err.Type, err.Type)
|
||||||
|
assert.Equals(t, tc.err.Message, err.Message)
|
||||||
|
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||||
|
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||||
|
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// nothing to test when the requireAPIEnabled middleware succeeds, currently
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
ctx context.Context
|
||||||
|
auth api.LinkedAuthority
|
||||||
|
req *http.Request
|
||||||
|
next nextHTTP
|
||||||
|
err *admin.Error
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/missing-authorization-token": func(t *testing.T) test {
|
||||||
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
|
req.Header["Authorization"] = []string{""}
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
req: req,
|
||||||
|
statusCode: 401,
|
||||||
|
err: &admin.Error{
|
||||||
|
Type: admin.ErrorUnauthorizedType.String(),
|
||||||
|
Status: 401,
|
||||||
|
Detail: "unauthorized",
|
||||||
|
Message: "missing authorization header token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/auth.AuthorizeAdminToken": func(t *testing.T) test {
|
||||||
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
|
req.Header["Authorization"] = []string{"token"}
|
||||||
|
auth := &api.MockAuthority{
|
||||||
|
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||||
|
assert.Equals(t, "token", token)
|
||||||
|
return nil, admin.NewError(
|
||||||
|
admin.ErrorUnauthorizedType,
|
||||||
|
"not authorized",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
auth: auth,
|
||||||
|
req: req,
|
||||||
|
statusCode: 401,
|
||||||
|
err: &admin.Error{
|
||||||
|
Type: admin.ErrorUnauthorizedType.String(),
|
||||||
|
Status: 401,
|
||||||
|
Detail: "unauthorized",
|
||||||
|
Message: "not authorized",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
|
req.Header["Authorization"] = []string{"token"}
|
||||||
|
createdAt := time.Now()
|
||||||
|
var deletedAt time.Time
|
||||||
|
admin := &linkedca.Admin{
|
||||||
|
Id: "adminID",
|
||||||
|
AuthorityId: "authorityID",
|
||||||
|
Subject: "admin",
|
||||||
|
ProvisionerId: "provID",
|
||||||
|
Type: linkedca.Admin_SUPER_ADMIN,
|
||||||
|
CreatedAt: timestamppb.New(createdAt),
|
||||||
|
DeletedAt: timestamppb.New(deletedAt),
|
||||||
|
}
|
||||||
|
auth := &api.MockAuthority{
|
||||||
|
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||||
|
assert.Equals(t, "token", token)
|
||||||
|
return admin, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
next := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin
|
||||||
|
adm, ok := a.(*linkedca.Admin)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("expected *linkedca.Admin; got %T", a)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||||
|
if !cmp.Equal(admin, adm, opts...) {
|
||||||
|
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...))
|
||||||
|
}
|
||||||
|
w.Write(nil) // mock response with status 200
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
auth: auth,
|
||||||
|
req: req,
|
||||||
|
next: next,
|
||||||
|
statusCode: 200,
|
||||||
|
err: nil,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, prep := range tests {
|
||||||
|
tc := prep(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := &Handler{
|
||||||
|
auth: tc.auth,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tc.req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 {
|
||||||
|
err := admin.Error{}
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||||
|
|
||||||
|
assert.Equals(t, tc.err.Type, err.Type)
|
||||||
|
assert.Equals(t, tc.err.Message, err.Message)
|
||||||
|
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||||
|
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||||
|
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue