forked from TrueCloudLab/certificates
Merge branch 'master' into ssh-renew-provisioner
This commit is contained in:
commit
1be74eca62
62 changed files with 2601 additions and 1833 deletions
|
@ -67,8 +67,11 @@ func (u *UpdateAccountRequest) Validate() error {
|
|||
}
|
||||
|
||||
// NewAccount is the handler resource for creating new ACME accounts.
|
||||
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -114,7 +117,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
eak, err := h.validateExternalAccountBinding(ctx, &nar)
|
||||
eak, err := validateExternalAccountBinding(ctx, &nar)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -125,7 +128,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
Contact: nar.Contact,
|
||||
Status: acme.StatusValid,
|
||||
}
|
||||
if err := h.db.CreateAccount(ctx, acc); err != nil {
|
||||
if err := db.CreateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
||||
return
|
||||
}
|
||||
|
@ -135,7 +138,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||
if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
||||
return
|
||||
}
|
||||
|
@ -146,15 +149,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
httpStatus = http.StatusOK
|
||||
}
|
||||
|
||||
h.linker.LinkAccount(ctx, acc)
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
|
||||
w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
|
||||
render.JSONStatus(w, acc, httpStatus)
|
||||
}
|
||||
|
||||
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||
func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -186,16 +192,16 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
|||
acc.Contact = uar.Contact
|
||||
}
|
||||
|
||||
if err := h.db.UpdateAccount(ctx, acc); err != nil {
|
||||
if err := db.UpdateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h.linker.LinkAccount(ctx, acc)
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
|
||||
render.JSON(w, acc)
|
||||
}
|
||||
|
||||
|
@ -209,8 +215,11 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
|||
}
|
||||
|
||||
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
|
||||
func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -221,13 +230,14 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||
return
|
||||
}
|
||||
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
|
||||
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkOrdersByAccountID(ctx, orders)
|
||||
linker.LinkOrdersByAccountID(ctx, orders)
|
||||
|
||||
render.JSON(w, orders)
|
||||
logOrdersByAccount(w, orders)
|
||||
|
|
|
@ -31,6 +31,22 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
type fakeProvisioner struct{}
|
||||
|
||||
func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil }
|
||||
func (*fakeProvisioner) GetID() string { return "" }
|
||||
func (*fakeProvisioner) GetName() string { return "" }
|
||||
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
|
||||
func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil }
|
||||
|
||||
func newProv() acme.Provisioner {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
|
@ -320,10 +336,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: accID}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
|
||||
|
@ -339,11 +354,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrdersByAccountID(w, req)
|
||||
GetOrdersByAccountID(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -387,6 +402,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -395,6 +411,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -403,6 +420,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to "+
|
||||
|
@ -417,6 +435,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||
|
@ -429,8 +448,9 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -442,9 +462,10 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jwk expected in request context"),
|
||||
|
@ -456,10 +477,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jwk expected in request context"),
|
||||
|
@ -478,9 +500,9 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
|
||||
|
@ -495,7 +517,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -525,18 +547,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
scepProvisioner := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
|
||||
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
|
||||
|
@ -575,8 +590,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
|
@ -623,8 +637,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||
|
@ -659,11 +672,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
acc: acc,
|
||||
statusCode: 200,
|
||||
|
@ -688,8 +701,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = false
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||
|
@ -743,8 +755,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -783,11 +794,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewAccount(w, req)
|
||||
NewAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -838,6 +849,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -846,6 +858,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -854,6 +867,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
"fail/no-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -863,6 +877,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -872,6 +887,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
|
||||
|
@ -886,6 +902,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||
|
@ -918,10 +935,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||
|
@ -938,11 +954,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
uar := &UpdateAccountRequest{}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -953,10 +969,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||
|
@ -970,11 +985,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok/post-as-get": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -983,11 +998,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrUpdateAccount(w, req)
|
||||
GetOrUpdateAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
|
|
@ -17,7 +17,7 @@ type ExternalAccountBinding struct {
|
|||
}
|
||||
|
||||
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
|
||||
func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||
func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
|
||||
|
@ -48,7 +48,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
|
|||
return nil, acmeErr
|
||||
}
|
||||
|
||||
externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
if err != nil {
|
||||
if _, ok := err.(*acme.Error); ok {
|
||||
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
|
||||
|
@ -111,7 +112,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
|
|||
// o The "nonce" field MUST NOT be present
|
||||
// o The "url" field MUST be set to the same value as the outer JWS
|
||||
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
|
||||
|
||||
if jws == nil {
|
||||
return "", acme.NewErrorISE("no JWS provided")
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
func Test_keysAreEqual(t *testing.T) {
|
||||
|
@ -100,8 +99,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
|
@ -145,8 +143,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
return test{
|
||||
|
@ -191,17 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
scepProvisioner := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
|
||||
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
|
||||
|
@ -220,8 +210,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
|
@ -266,8 +255,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
|
@ -312,8 +300,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -360,8 +347,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -410,8 +396,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -460,8 +445,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -510,8 +494,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
return test{
|
||||
|
@ -568,8 +551,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -616,8 +598,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
boundAt := time.Now().Add(1 * time.Second)
|
||||
|
@ -676,8 +657,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -734,8 +714,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -789,8 +768,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -845,8 +823,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, nil)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -873,10 +850,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
db: tc.db,
|
||||
}
|
||||
got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar)
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
|
||||
got, err := validateExternalAccountBinding(ctx, tc.nar)
|
||||
wantErr := tc.err != nil
|
||||
gotErr := err != nil
|
||||
if wantErr != gotErr {
|
||||
|
|
|
@ -2,12 +2,10 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -16,6 +14,7 @@ import (
|
|||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
|
@ -39,111 +38,152 @@ type payloadInfo struct {
|
|||
isEmptyJSON bool
|
||||
}
|
||||
|
||||
// Handler is the ACME API request handler.
|
||||
type Handler struct {
|
||||
db acme.DB
|
||||
backdate provisioner.Duration
|
||||
ca acme.CertificateAuthority
|
||||
linker Linker
|
||||
validateChallengeOptions *acme.ValidateChallengeOptions
|
||||
prerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
// HandlerOptions required to create a new ACME API request handler.
|
||||
type HandlerOptions struct {
|
||||
Backdate provisioner.Duration
|
||||
// DB storage backend that impements the acme.DB interface.
|
||||
// DB storage backend that implements the acme.DB interface.
|
||||
//
|
||||
// Deprecated: use acme.NewContex(context.Context, acme.DB)
|
||||
DB acme.DB
|
||||
|
||||
// CA is the certificate authority interface.
|
||||
//
|
||||
// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
|
||||
CA acme.CertificateAuthority
|
||||
|
||||
// Backdate is the duration that the CA will subtract from the current time
|
||||
// to set the NotBefore in the certificate.
|
||||
Backdate provisioner.Duration
|
||||
|
||||
// DNS the host used to generate accurate ACME links. By default the authority
|
||||
// will use the Host from the request, so this value will only be used if
|
||||
// request.Host is empty.
|
||||
DNS string
|
||||
|
||||
// Prefix is a URL path prefix under which the ACME api is served. This
|
||||
// prefix is required to generate accurate ACME links.
|
||||
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
||||
// "acme" is the prefix from which the ACME api is accessed.
|
||||
Prefix string
|
||||
CA acme.CertificateAuthority
|
||||
|
||||
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
||||
// met by the CA configuration.
|
||||
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// handler is the ACME API request handler.
|
||||
type handler struct {
|
||||
opts *HandlerOptions
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface. For backward compatibility
|
||||
// this route adds will add a new middleware that will set the ACME components
|
||||
// on the context.
|
||||
//
|
||||
// Note: this method is deprecated in step-ca, other applications can still use
|
||||
// this to support ACME, but the recommendation is to use use
|
||||
// api.Route(api.Router) and acme.NewContext() instead.
|
||||
func (h *handler) Route(r api.Router) {
|
||||
client := acme.NewClient()
|
||||
linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix)
|
||||
route(r, func(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
|
||||
ctx = authority.NewContext(ctx, ca)
|
||||
}
|
||||
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// NewHandler returns a new ACME API handler.
|
||||
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
prerequisitesChecker := func(ctx context.Context) (bool, error) {
|
||||
// by default all prerequisites are met
|
||||
return true, nil
|
||||
}
|
||||
if ops.PrerequisitesChecker != nil {
|
||||
prerequisitesChecker = ops.PrerequisitesChecker
|
||||
}
|
||||
return &Handler{
|
||||
ca: ops.CA,
|
||||
db: ops.DB,
|
||||
backdate: ops.Backdate,
|
||||
linker: NewLinker(ops.DNS, ops.Prefix),
|
||||
validateChallengeOptions: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: client.Get,
|
||||
LookupTxt: net.LookupTXT,
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(dialer, network, addr, config)
|
||||
},
|
||||
},
|
||||
prerequisitesChecker: prerequisitesChecker,
|
||||
//
|
||||
// Note: this method is deprecated in step-ca, other applications can still use
|
||||
// this to support ACME, but the recommendation is to use use
|
||||
// api.Route(api.Router) and acme.NewContext() instead.
|
||||
func NewHandler(opts HandlerOptions) api.RouterHandler {
|
||||
return &handler{
|
||||
opts: &opts,
|
||||
}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
getPath := h.linker.GetUnescapedPathSuffix
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||
// Route traffic and implement the Router interface. This method requires that
|
||||
// all the acme components, authority, db, client, linker, and prerequisite
|
||||
// checker to be present in the context.
|
||||
func Route(r api.Router) {
|
||||
route(r, nil)
|
||||
}
|
||||
|
||||
func route(r api.Router, middleware func(next nextHTTP) nextHTTP) {
|
||||
commonMiddleware := func(next nextHTTP) nextHTTP {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
// Linker middleware gets the provisioner and current url from the
|
||||
// request and sets them in the context.
|
||||
linker := acme.MustLinkerFromContext(r.Context())
|
||||
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
|
||||
}
|
||||
if middleware != nil {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
return handler
|
||||
}
|
||||
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
|
||||
return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
|
||||
}
|
||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
|
||||
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
|
||||
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
|
||||
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented))
|
||||
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder))
|
||||
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
||||
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID)))
|
||||
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
||||
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
|
||||
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert))
|
||||
getPath := acme.GetUnescapedPathSuffix
|
||||
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
|
||||
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
|
||||
extractPayloadByJWK(NewAccount))
|
||||
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(GetOrUpdateAccount))
|
||||
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(NotImplemented))
|
||||
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
|
||||
extractPayloadByKid(NewOrder))
|
||||
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrder)))
|
||||
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
|
||||
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(FinalizeOrder))
|
||||
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetAuthorization)))
|
||||
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
|
||||
extractPayloadByKid(GetChallenge))
|
||||
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetCertificate)))
|
||||
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
|
||||
extractPayloadByKidOrJWK(RevokeCert))
|
||||
}
|
||||
|
||||
// GetNonce just sets the right header since a Nonce is added to each response
|
||||
// by middleware by default.
|
||||
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
func GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
|
@ -179,7 +219,7 @@ func (d *Directory) ToLog() (interface{}, error) {
|
|||
|
||||
// GetDirectory is the ACME resource for returning a directory configuration
|
||||
// for client configuration.
|
||||
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
func GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -187,12 +227,13 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
render.JSON(w, &Directory{
|
||||
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
|
||||
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
|
||||
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
|
||||
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType),
|
||||
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType),
|
||||
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
|
||||
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
|
||||
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
|
||||
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
|
||||
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
|
||||
Meta: Meta{
|
||||
ExternalAccountRequired: acmeProv.RequireEAB,
|
||||
},
|
||||
|
@ -201,19 +242,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
||||
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
||||
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
func NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||
}
|
||||
|
||||
// GetAuthorization ACME api for retrieving an Authz.
|
||||
func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||
az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||
return
|
||||
|
@ -223,20 +267,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
|||
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
||||
return
|
||||
}
|
||||
if err = az.UpdateStatus(ctx, h.db); err != nil {
|
||||
if err = az.UpdateStatus(ctx, db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkAuthorization(ctx, az)
|
||||
linker.LinkAuthorization(ctx, az)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
|
||||
render.JSON(w, az)
|
||||
}
|
||||
|
||||
// GetChallenge ACME api for retrieving a Challenge.
|
||||
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -257,7 +304,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
|||
// we'll just ignore the body.
|
||||
|
||||
azID := chi.URLParam(r, "authzID")
|
||||
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||
ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||
return
|
||||
|
@ -273,29 +320,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
|
||||
if err = ch.Validate(ctx, db, jwk); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkChallenge(ctx, ch, azID)
|
||||
linker.LinkChallenge(ctx, ch, azID)
|
||||
|
||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
|
||||
render.JSON(w, ch)
|
||||
}
|
||||
|
||||
// GetCertificate ACME api for retrieving a Certificate.
|
||||
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
func GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
certID := chi.URLParam(r, "certID")
|
||||
|
||||
cert, err := h.db.GetCertificate(ctx, certID)
|
||||
certID := chi.URLParam(r, "certID")
|
||||
cert, err := db.GetCertificate(ctx, certID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||
return
|
||||
|
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
|
@ -19,11 +20,33 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
get func(url string) (*http.Response, error)
|
||||
lookupTxt func(name string) ([]string, error)
|
||||
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) }
|
||||
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
|
||||
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return m.tlsDial(network, addr, config)
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_GetNonce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -38,10 +61,10 @@ func TestHandler_GetNonce(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
// h := &Handler{}
|
||||
w := httptest.NewRecorder()
|
||||
req.Method = tt.name
|
||||
h.GetNonce(w, req)
|
||||
GetNonce(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -52,7 +75,8 @@ func TestHandler_GetNonce(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandler_GetDirectory(t *testing.T) {
|
||||
linker := NewLinker("ca.smallstep.com", "acme")
|
||||
linker := acme.NewLinker("ca.smallstep.com", "acme")
|
||||
_ = linker
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
|
@ -61,23 +85,14 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"),
|
||||
err: acme.NewErrorISE("provisioner is not in context"),
|
||||
}
|
||||
},
|
||||
"fail/different-provisioner": func(t *testing.T) test {
|
||||
prov := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
|
@ -88,8 +103,7 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
|
@ -108,8 +122,7 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
|
@ -130,11 +143,11 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: linker}
|
||||
ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetDirectory(w, req)
|
||||
GetDirectory(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -219,7 +232,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
|
@ -285,10 +298,9 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||
|
@ -304,11 +316,11 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAuthorization(w, req)
|
||||
GetAuthorization(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -447,11 +459,11 @@ func TestHandler_GetCertificate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetCertificate(w, req)
|
||||
GetCertificate(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -491,7 +503,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
|
||||
type test struct {
|
||||
db acme.DB
|
||||
vco *acme.ValidateChallengeOptions
|
||||
vc acme.Client
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
ch *acme.Challenge
|
||||
|
@ -500,6 +512,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -507,6 +520,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), accContextKey, nil),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -516,6 +530,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -523,10 +538,11 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -534,7 +550,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -553,7 +569,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -572,7 +588,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/no-jwk": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -591,7 +607,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-jwk": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||
|
@ -611,7 +627,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/validate-challenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
|
@ -639,8 +655,8 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
return acme.NewErrorISE("force")
|
||||
},
|
||||
},
|
||||
vco: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: func(string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -651,14 +667,13 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
_pub := _jwk.Public()
|
||||
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -690,8 +705,8 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
URL: u,
|
||||
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
||||
},
|
||||
vco: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: func(string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -703,11 +718,11 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetChallenge(w, req)
|
||||
GetChallenge(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
|
||||
|
@ -31,39 +30,11 @@ func logNonce(w http.ResponseWriter, nonce string) {
|
|||
}
|
||||
}
|
||||
|
||||
// baseURLFromRequest determines the base URL which should be used for
|
||||
// constructing link URLs in e.g. the ACME directory result by taking the
|
||||
// request Host into consideration.
|
||||
//
|
||||
// If the Request.Host is an empty string, we return an empty string, to
|
||||
// indicate that the configured URL values should be used instead. If this
|
||||
// function returns a non-empty result, then this should be used in
|
||||
// constructing ACME link URLs.
|
||||
func baseURLFromRequest(r *http.Request) *url.URL {
|
||||
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
|
||||
// for an implementation that allows HTTP requests using the x-forwarded-proto
|
||||
// header.
|
||||
|
||||
if r.Host == "" {
|
||||
return nil
|
||||
}
|
||||
return &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
|
||||
// baseURLFromRequest is a middleware that extracts and caches the baseURL
|
||||
// from the request.
|
||||
// E.g. https://ca.smallstep.com/
|
||||
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||
func addNonce(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
nonce, err := h.db.CreateNonce(r.Context())
|
||||
db := acme.MustDatabaseFromContext(r.Context())
|
||||
nonce, err := db.CreateNonce(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -77,25 +48,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
|||
|
||||
// addDirLink is a middleware that adds a 'Link' response reader with the
|
||||
// directory index url.
|
||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||
func addDirLink(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index"))
|
||||
ctx := r.Context()
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// verifyContentType is a middleware that verifies that content type is
|
||||
// application/jose+json.
|
||||
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||
func verifyContentType(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var expected []string
|
||||
p, err := provisionerFromContext(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
|
||||
u := &url.URL{
|
||||
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
|
||||
}
|
||||
|
||||
var expected []string
|
||||
if strings.Contains(r.URL.String(), u.EscapedPath()) {
|
||||
// GET /certificate requests allow a greater range of content types.
|
||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||
|
@ -117,7 +94,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
||||
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||
func parseJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
|
@ -149,10 +126,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
|||
// * “nonce” (defined in Section 6.5)
|
||||
// * “url” (defined in Section 6.4)
|
||||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||
func validateJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(r.Context())
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -202,7 +181,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// Check the validity/freshness of the Nonce.
|
||||
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
@ -235,10 +214,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
||||
// in the context. Make sure to parse and validate the JWS before running this
|
||||
// middleware.
|
||||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||
func extractJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(r.Context())
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -264,7 +245,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
|||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
|
||||
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
|
||||
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||
acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||
switch {
|
||||
case errors.Is(err, acme.ErrNotFound):
|
||||
// For NewAccount and Revoke requests ...
|
||||
|
@ -283,38 +264,15 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
|||
}
|
||||
}
|
||||
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responds 404 if the provisioner does not exist.
|
||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
p, err := h.ca.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// checkPrerequisites checks if all prerequisites for serving ACME
|
||||
// are met by the CA configuration.
|
||||
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
func checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ok, err := h.prerequisitesChecker(ctx)
|
||||
// If the function is not set assume that all prerequisites are met.
|
||||
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
|
||||
if ok {
|
||||
ok, err := checkFunc(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
|
@ -323,23 +281,27 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
|
|||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||
// kid parameter of the signed payload.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||
func lookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
|
||||
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
|
@ -349,7 +311,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||
acc, err := h.db.GetAccount(ctx, accID)
|
||||
acc, err := db.GetAccount(ctx, accID)
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||
|
@ -372,7 +334,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
|||
|
||||
// extractOrLookupJWK forwards handling to either extractJWK or
|
||||
// lookupJWK based on the presence of a JWK or a KID, respectively.
|
||||
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||
func extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
|
@ -385,13 +347,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
|||
// and it can be used to check if a JWK exists. This flow is used when the ACME client
|
||||
// signed the payload with a certificate private key.
|
||||
if canExtractJWKFrom(jws) {
|
||||
h.extractJWK(next)(w, r)
|
||||
extractJWK(next)(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
|
||||
// signed the payload with an account private key.
|
||||
h.lookupJWK(next)(w, r)
|
||||
lookupJWK(next)(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -408,7 +370,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
|||
|
||||
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
|
@ -440,7 +402,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
||||
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||
func isPostAsGet(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, err := payloadFromContext(r.Context())
|
||||
if err != nil {
|
||||
|
@ -462,16 +424,12 @@ type ContextKey string
|
|||
const (
|
||||
// accContextKey account key
|
||||
accContextKey = ContextKey("acc")
|
||||
// baseURLContextKey baseURL key
|
||||
baseURLContextKey = ContextKey("baseURL")
|
||||
// jwsContextKey jws key
|
||||
jwsContextKey = ContextKey("jws")
|
||||
// jwkContextKey jwk key
|
||||
jwkContextKey = ContextKey("jwk")
|
||||
// payloadContextKey payload key
|
||||
payloadContextKey = ContextKey("payload")
|
||||
// provisionerContextKey provisioner key
|
||||
provisionerContextKey = ContextKey("provisioner")
|
||||
)
|
||||
|
||||
// accountFromContext searches the context for an ACME account. Returns the
|
||||
|
@ -484,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) {
|
|||
return val, nil
|
||||
}
|
||||
|
||||
// baseURLFromContext returns the baseURL if one is stored in the context.
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
|
||||
if !ok || val == nil {
|
||||
return nil
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
||||
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
||||
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
|
||||
|
@ -514,29 +463,26 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
|||
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||
// provisioner or an error.
|
||||
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
||||
val := ctx.Value(provisionerContextKey)
|
||||
if val == nil {
|
||||
p, ok := acme.ProvisionerFromContext(ctx)
|
||||
if !ok || p == nil {
|
||||
return nil, acme.NewErrorISE("provisioner expected in request context")
|
||||
}
|
||||
pval, ok := val.(acme.Provisioner)
|
||||
if !ok || pval == nil {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
return pval, nil
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
|
||||
// pointer to an ACME provisioner or an error.
|
||||
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
p, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
acmeProv, ok := prov.(*provisioner.ACME)
|
||||
if !ok || acmeProv == nil {
|
||||
ap, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
return acmeProv, nil
|
||||
|
||||
return ap, nil
|
||||
}
|
||||
|
||||
// payloadFromContext searches the context for a payload. Returns the payload
|
||||
|
|
|
@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) {
|
|||
w.Write(testBody)
|
||||
}
|
||||
|
||||
func Test_baseURLFromRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
targetURL string
|
||||
expectedResult *url.URL
|
||||
requestPreparer func(*http.Request)
|
||||
}{
|
||||
{
|
||||
"HTTPS host pass-through failed.",
|
||||
"https://my.dummy.host",
|
||||
&url.URL{Scheme: "https", Host: "my.dummy.host"},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Port pass-through failed",
|
||||
"https://host.with.port:8080",
|
||||
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Explicit host from Request.Host was not used.",
|
||||
"https://some.target.host:8080",
|
||||
&url.URL{Scheme: "https", Host: "proxied.host"},
|
||||
func(r *http.Request) {
|
||||
r.Host = "proxied.host"
|
||||
},
|
||||
},
|
||||
{
|
||||
"Missing Request.Host value did not result in empty string result.",
|
||||
"https://some.host",
|
||||
nil,
|
||||
func(r *http.Request) {
|
||||
r.Host = ""
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
request := httptest.NewRequest("GET", tc.targetURL, nil)
|
||||
if tc.requestPreparer != nil {
|
||||
tc.requestPreparer(request)
|
||||
}
|
||||
result := baseURLFromRequest(request)
|
||||
if result == nil || tc.expectedResult == nil {
|
||||
assert.Equals(t, result, tc.expectedResult)
|
||||
} else if result.String() != tc.expectedResult.String() {
|
||||
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
|
||||
}
|
||||
})
|
||||
func newBaseContext(ctx context.Context, args ...interface{}) context.Context {
|
||||
for _, a := range args {
|
||||
switch v := a.(type) {
|
||||
case acme.DB:
|
||||
ctx = acme.NewDatabaseContext(ctx, v)
|
||||
case acme.Linker:
|
||||
ctx = acme.NewLinkerContext(ctx, v)
|
||||
case acme.PrerequisitesChecker:
|
||||
ctx = acme.NewPrerequisitesCheckerContext(ctx, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_baseURLFromRequest(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Host = "test.ca.smallstep.com:8080"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
next := func(w http.ResponseWriter, r *http.Request) {
|
||||
bu := baseURLFromContext(r.Context())
|
||||
if assert.NotNil(t, bu) {
|
||||
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
|
||||
assert.Equals(t, bu.Scheme, "https")
|
||||
}
|
||||
}
|
||||
|
||||
h.baseURLFromRequest(next)(w, req)
|
||||
|
||||
req = httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Host = ""
|
||||
|
||||
next = func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, baseURLFromContext(r.Context()), nil)
|
||||
}
|
||||
|
||||
h.baseURLFromRequest(next)(w, req)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func TestHandler_addNonce(t *testing.T) {
|
||||
|
@ -139,10 +74,10 @@ func TestHandler_addNonce(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
ctx := newBaseContext(context.Background(), tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.addNonce(testNext)(w, req)
|
||||
addNonce(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) {
|
|||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
type test struct {
|
||||
link string
|
||||
linker Linker
|
||||
statusCode int
|
||||
ctx context.Context
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
|
||||
statusCode: 200,
|
||||
|
@ -195,11 +128,10 @@ func TestHandler_addDirLink(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: tc.linker}
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.addDirLink(testNext)(w, req)
|
||||
addDirLink(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
||||
type test struct {
|
||||
h Handler
|
||||
ctx context.Context
|
||||
contentType string
|
||||
err *acme.Error
|
||||
|
@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/provisioner-not-set": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
url: u,
|
||||
ctx: context.Background(),
|
||||
contentType: "foo",
|
||||
|
@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"fail/general-bad-content-type": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
url: u,
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "foo",
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
|
||||
|
@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-bad-content-type": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "foo",
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
|
||||
|
@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "application/jose+json",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/pkix-cert": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "application/pkix-cert",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/jose+json": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "application/jose+json",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "application/pkcs7-mime",
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -326,7 +236,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
req = req.WithContext(tc.ctx)
|
||||
req.Header.Add("Content-Type", tc.contentType)
|
||||
w := httptest.NewRecorder()
|
||||
tc.h.verifyContentType(testNext)(w, req)
|
||||
verifyContentType(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -390,11 +300,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
// h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.isPostAsGet(testNext)(w, req)
|
||||
isPostAsGet(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -481,10 +391,10 @@ func TestHandler_parseJWS(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
// h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, tc.body)
|
||||
w := httptest.NewRecorder()
|
||||
h.parseJWS(tc.next)(w, req)
|
||||
parseJWS(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -679,11 +589,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
// h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||
verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
linker Linker
|
||||
linker acme.Linker
|
||||
db acme.DB
|
||||
ctx context.Context
|
||||
next func(http.ResponseWriter, *http.Request)
|
||||
|
@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
}
|
||||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
_jws, err := _signer.Sign([]byte("baz"))
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
|
||||
|
@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
_parsed, err := jose.ParseJWS(_raw)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _parsed)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
|
||||
}
|
||||
},
|
||||
"fail/account-not-found": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||
assert.Equals(t, accID, accID)
|
||||
|
@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/GetAccount-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "deactivated"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "valid", Key: jwk}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -881,11 +791,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: tc.linker}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.lookupJWK(tc.next)(w, req)
|
||||
lookupJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
}
|
||||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
|
||||
|
@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
|
||||
}
|
||||
},
|
||||
"fail/GetAccountByKey-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "deactivated"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "valid"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok/no-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1077,11 +991,11 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
ctx := newBaseContext(tc.ctx, tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractJWK(tc.next)(w, req)
|
||||
extractJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, nil),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
"fail/no-signature": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
|
||||
|
@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
|
||||
|
@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
|
||||
|
@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
|
||||
|
@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
|
||||
|
@ -1444,11 +1365,11 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
ctx := newBaseContext(tc.ctx, tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.validateJWS(tc.next)(w, req)
|
||||
validateJWS(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
u := "https://ca.smallstep.com/acme/account"
|
||||
type test struct {
|
||||
db acme.DB
|
||||
linker Linker
|
||||
linker acme.Linker
|
||||
statusCode int
|
||||
ctx context.Context
|
||||
err *acme.Error
|
||||
|
@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
||||
assert.Equals(t, kid, pub.KeyID)
|
||||
|
@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
linker: NewLinker("test.ca.smallstep.com", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
|
@ -1628,11 +1548,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: tc.linker}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractOrLookupJWK(tc.next)(w, req)
|
||||
extractOrLookupJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||
baseURL, provName)
|
||||
type test struct {
|
||||
linker Linker
|
||||
linker acme.Linker
|
||||
ctx context.Context
|
||||
prerequisitesChecker func(context.Context) (bool, error)
|
||||
next func(http.ResponseWriter, *http.Request)
|
||||
|
@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/prerequisites-nok": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1717,11 +1634,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
||||
ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.checkPrerequisites(tc.next)(w, req)
|
||||
checkPrerequisites(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
|
|
@ -72,8 +72,12 @@ var defaultOrderExpiry = time.Hour * 24
|
|||
var defaultOrderBackdate = time.Minute
|
||||
|
||||
// NewOrder ACME api for creating a new order.
|
||||
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ca := mustAuthority(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -113,7 +117,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var eak *acme.ExternalAccountKey
|
||||
if acmeProv.RequireEAB {
|
||||
if eak, err = h.db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
|
||||
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
|
||||
return
|
||||
}
|
||||
|
@ -138,7 +142,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
// evaluate the authority level policy
|
||||
if err = h.ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
|
||||
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
|
@ -164,7 +168,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
ExpiresAt: o.ExpiresAt,
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := h.newAuthorization(ctx, az); err != nil {
|
||||
if err := newAuthorization(ctx, az); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
@ -183,14 +187,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
|
||||
}
|
||||
|
||||
if err := h.db.CreateOrder(ctx, o); err != nil {
|
||||
if err := db.CreateOrder(ctx, o); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSONStatus(w, o, http.StatusCreated)
|
||||
}
|
||||
|
||||
|
@ -208,7 +212,7 @@ func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error
|
|||
return policy.NewX509PolicyEngine(eak.Policy)
|
||||
}
|
||||
|
||||
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
if strings.HasPrefix(az.Identifier.Value, "*.") {
|
||||
az.Wildcard = true
|
||||
az.Identifier = acme.Identifier{
|
||||
|
@ -224,6 +228,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
|
|||
if err != nil {
|
||||
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
az.Challenges = make([]*acme.Challenge, len(chTypes))
|
||||
for i, typ := range chTypes {
|
||||
ch := &acme.Challenge{
|
||||
|
@ -233,20 +239,23 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
|
|||
Token: az.Token,
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := h.db.CreateChallenge(ctx, ch); err != nil {
|
||||
if err := db.CreateChallenge(ctx, ch); err != nil {
|
||||
return acme.WrapErrorISE(err, "error creating challenge")
|
||||
}
|
||||
az.Challenges[i] = ch
|
||||
}
|
||||
if err = h.db.CreateAuthorization(ctx, az); err != nil {
|
||||
if err = db.CreateAuthorization(ctx, az); err != nil {
|
||||
return acme.WrapErrorISE(err, "error creating authorization")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOrder ACME api for retrieving an order.
|
||||
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
func GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -257,7 +266,8 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
|
@ -272,20 +282,23 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
|||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
if err = o.UpdateStatus(ctx, h.db); err != nil {
|
||||
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
// FinalizeOrder attempts to finalize an order and create a certificate.
|
||||
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -312,7 +325,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
|
@ -327,14 +340,16 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
|
||||
|
||||
ca := mustAuthority(ctx)
|
||||
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
|
|
|
@ -280,15 +280,17 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -298,6 +300,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -305,9 +308,10 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), nil)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -315,7 +319,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/db.GetOrder-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
|
@ -329,7 +333,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
|
@ -345,7 +349,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/provisioner-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
|
@ -361,7 +365,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/order-update-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
|
@ -385,10 +389,9 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
|
||||
|
@ -425,11 +428,11 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrder(w, req)
|
||||
GetOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -640,8 +643,8 @@ func TestHandler_newAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
h := &Handler{db: tc.db}
|
||||
if err := h.newAuthorization(context.Background(), tc.az); err != nil {
|
||||
ctx := newBaseContext(context.Background(), tc.db)
|
||||
if err := newAuthorization(ctx, tc.az); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
|
@ -682,15 +685,17 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -700,6 +705,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -707,9 +713,10 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -718,8 +725,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -727,21 +735,23 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("paylod does not exist"),
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
|
||||
|
@ -752,10 +762,11 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
fr := &NewOrderRequest{}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
|
||||
|
@ -770,7 +781,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, &acme.MockProvisioner{})
|
||||
ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{})
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
|
@ -798,7 +809,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
|
@ -826,7 +837,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
|
@ -862,7 +873,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
|
@ -905,7 +916,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
|
@ -948,7 +959,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
|
@ -986,7 +997,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
|
@ -1020,7 +1031,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
var (
|
||||
|
@ -1096,10 +1107,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3, ch4 **acme.Challenge
|
||||
az1ID, az2ID *string
|
||||
|
@ -1217,10 +1227,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -1315,10 +1324,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -1412,10 +1420,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -1510,10 +1517,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -1611,10 +1617,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -1701,11 +1706,12 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
|
||||
mockMustAuthority(t, tc.ca)
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewOrder(w, req)
|
||||
NewOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1738,6 +1744,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandler_FinalizeOrder(t *testing.T) {
|
||||
mockMustAuthority(t, &mockCA{})
|
||||
prov := newProv()
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
|
@ -1796,15 +1803,17 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -1814,6 +1823,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -1821,9 +1831,10 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -1832,8 +1843,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -1841,21 +1853,23 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("paylod does not exist"),
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
|
||||
|
@ -1866,10 +1880,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
fr := &FinalizeRequest{}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
|
||||
|
@ -1878,7 +1893,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
"fail/db.GetOrder-error": func(t *testing.T) test {
|
||||
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -1893,7 +1908,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -1910,7 +1925,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/provisioner-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -1927,7 +1942,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/order-finalize-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -1952,10 +1967,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -1991,11 +2005,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.FinalizeOrder(w, req)
|
||||
FinalizeOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
|
|
@ -26,9 +26,11 @@ type revokePayload struct {
|
|||
}
|
||||
|
||||
// RevokeCert attempts to revoke a certificate.
|
||||
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -69,7 +71,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
serial := certToBeRevoked.SerialNumber.String()
|
||||
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
||||
dbCert, err := db.GetCertificateBySerial(ctx, serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||
return
|
||||
|
@ -87,7 +89,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
if acmeErr != nil {
|
||||
render.Error(w, acmeErr)
|
||||
return
|
||||
|
@ -103,7 +105,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
|
||||
ca := mustAuthority(ctx)
|
||||
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||
return
|
||||
|
@ -130,14 +133,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
||||
err = h.ca.Revoke(ctx, options)
|
||||
err = ca.Revoke(ctx, options)
|
||||
if err != nil {
|
||||
render.Error(w, wrapRevokeErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
logRevoke(w, options)
|
||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index"))
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
w.Write(nil)
|
||||
}
|
||||
|
||||
|
@ -148,7 +151,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
|
||||
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
|
||||
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
|
||||
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||
func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||
if !account.IsValid() {
|
||||
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
|
||||
}
|
||||
|
|
|
@ -521,6 +521,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/no-jws": func(t *testing.T) test {
|
||||
ctx := context.Background()
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -529,6 +530,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -537,6 +539,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -544,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, nil)
|
||||
ctx = acme.NewProvisionerContext(ctx, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -553,8 +557,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -562,9 +567,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -573,9 +579,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/unmarshal-payload": func(t *testing.T) test {
|
||||
malformedPayload := []byte(`{"payload":malformed?}`)
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("error unmarshaling payload"),
|
||||
|
@ -587,10 +594,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: &acme.Error{
|
||||
|
@ -606,10 +614,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
emptyPayloadBytes, err := json.Marshal(emptyPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: &acme.Error{
|
||||
|
@ -620,7 +629,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/db.GetCertificateBySerial": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -638,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/different-certificate-contents": func(t *testing.T) test {
|
||||
aDifferentCert, _, err := generateCertKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -657,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -676,7 +685,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
|
@ -697,11 +706,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -727,11 +735,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-authorized": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -781,10 +788,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -808,11 +814,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-revoked-check-fails": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -842,7 +847,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-already-revoked": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -880,7 +885,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -918,7 +923,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
}
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -950,7 +955,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/ca.Revoke": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -982,7 +987,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/ca.Revoke-already-revoked": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -1013,11 +1018,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"ok/using-account-key": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -1041,10 +1045,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(jwsBytes))
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -1067,11 +1070,12 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
for name, setup := range tests {
|
||||
tc := setup(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
mockMustAuthority(t, tc.ca)
|
||||
req := httptest.NewRequest("POST", revokeURL, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.RevokeCert(w, req)
|
||||
RevokeCert(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1208,8 +1212,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
|
|||
for name, setup := range tests {
|
||||
tc := setup(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
|
||||
// h := &Handler{db: tc.db}
|
||||
acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
|
||||
|
||||
expectError := tc.err != nil
|
||||
gotError := acmeErr != nil
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
|
|||
// type using the DB interface.
|
||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||
// updated.
|
||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
|
||||
// If already valid or invalid then return without performing validation.
|
||||
if ch.Status != StatusPending {
|
||||
return nil
|
||||
}
|
||||
switch ch.Type {
|
||||
case HTTP01:
|
||||
return http01Validate(ctx, ch, db, jwk, vo)
|
||||
return http01Validate(ctx, ch, db, jwk)
|
||||
case DNS01:
|
||||
return dns01Validate(ctx, ch, db, jwk, vo)
|
||||
return dns01Validate(ctx, ch, db, jwk)
|
||||
case TLSALPN01:
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk)
|
||||
default:
|
||||
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||
|
||||
resp, err := vo.HTTPGet(u.String())
|
||||
vc := MustClientFromContext(ctx)
|
||||
resp, err := vc.Get(u.String())
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||
"error doing http GET for url %s", u))
|
||||
|
@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 {
|
|||
return 0
|
||||
}
|
||||
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
config := &tls.Config{
|
||||
NextProtos: []string{"acme-tls/1"},
|
||||
// https://tools.ietf.org/html/rfc8737#section-4
|
||||
|
@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
|
||||
hostPort := net.JoinHostPort(ch.Value, "443")
|
||||
|
||||
conn, err := vo.TLSDial("tcp", hostPort, config)
|
||||
vc := MustClientFromContext(ctx)
|
||||
conn, err := vc.TLSDial("tcp", hostPort, config)
|
||||
if err != nil {
|
||||
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
|
||||
// client and server protocols. When this happens the connection is
|
||||
|
@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
|
||||
}
|
||||
|
||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
// Normalize domain for wildcard DNS names
|
||||
// This is done to avoid making TXT lookups for domains like
|
||||
// _acme-challenge.*.example.com
|
||||
// Instead perform txt lookup for _acme-challenge.example.com
|
||||
domain := strings.TrimPrefix(ch.Value, "*.")
|
||||
|
||||
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
|
||||
vc := MustClientFromContext(ctx)
|
||||
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
|
||||
"error looking up TXT records for domain %s", domain))
|
||||
|
@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type httpGetter func(string) (*http.Response, error)
|
||||
type lookupTxt func(string) ([]string, error)
|
||||
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
|
||||
// ValidateChallengeOptions are ACME challenge validator functions.
|
||||
type ValidateChallengeOptions struct {
|
||||
HTTPGet httpGetter
|
||||
LookupTxt lookupTxt
|
||||
TLSDial tlsDialer
|
||||
}
|
||||
|
|
|
@ -29,6 +29,18 @@ import (
|
|||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
get func(url string) (*http.Response, error)
|
||||
lookupTxt func(name string) ([]string, error)
|
||||
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
|
||||
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
|
||||
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return m.tlsDial(network, addr, config)
|
||||
}
|
||||
|
||||
func Test_storeError(t *testing.T) {
|
||||
type test struct {
|
||||
ch *Challenge
|
||||
|
@ -229,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) {
|
|||
func TestChallenge_Validate(t *testing.T) {
|
||||
type test struct {
|
||||
ch *Challenge
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
srv *httptest.Server
|
||||
|
@ -273,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -309,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -344,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -381,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -416,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
}
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -466,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -493,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
defer tc.srv.Close()
|
||||
}
|
||||
|
||||
if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -524,7 +537,7 @@ func (errReader) Close() error {
|
|||
|
||||
func TestHTTP01Validate(t *testing.T) {
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -541,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -575,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -608,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: errReader(0),
|
||||
|
@ -645,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: errReader(0),
|
||||
|
@ -681,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: errReader(0),
|
||||
}, nil
|
||||
|
@ -704,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
jwk.Key = "foo"
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -730,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -772,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -815,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
||||
}, nil
|
||||
|
@ -857,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
||||
}, nil
|
||||
|
@ -887,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -911,7 +925,7 @@ func TestDNS01Validate(t *testing.T) {
|
|||
fulldomain := "*.zap.internal"
|
||||
domain := strings.TrimPrefix(fulldomain, "*.")
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -928,8 +942,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -963,8 +977,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1001,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1026,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", "bar"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1068,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", "bar"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1111,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", expected}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1156,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", expected}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1186,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -1206,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error)
|
||||
|
||||
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
|
||||
srv := httptest.NewUnstartedServer(http.NewServeMux())
|
||||
|
||||
|
@ -1309,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -1321,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
ch := makeTLSCh()
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1351,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
ch := makeTLSCh()
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1384,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1413,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.Client(&noopConn{}, config), nil
|
||||
},
|
||||
},
|
||||
|
@ -1443,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.Client(&noopConn{}, config), nil
|
||||
},
|
||||
},
|
||||
|
@ -1479,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
||||
},
|
||||
},
|
||||
|
@ -1516,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
||||
},
|
||||
},
|
||||
|
@ -1562,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1605,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1649,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1692,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1736,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
srv: srv,
|
||||
jwk: jwk,
|
||||
|
@ -1758,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1797,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1841,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1884,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1924,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1963,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2008,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2054,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2100,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2144,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2189,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2226,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2253,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
defer tc.srv.Close()
|
||||
}
|
||||
|
||||
if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
|
79
acme/client.go
Normal file
79
acme/client.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is the interface used to verify ACME challenges.
|
||||
type Client interface {
|
||||
// Get issues an HTTP GET to the specified URL.
|
||||
Get(url string) (*http.Response, error)
|
||||
|
||||
// LookupTXT returns the DNS TXT records for the given domain name.
|
||||
LookupTxt(name string) ([]string, error)
|
||||
|
||||
// TLSDial connects to the given network address using net.Dialer and then
|
||||
// initiates a TLS handshake, returning the resulting TLS connection.
|
||||
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
type clientKey struct{}
|
||||
|
||||
// NewClientContext adds the given client to the context.
|
||||
func NewClientContext(ctx context.Context, c Client) context.Context {
|
||||
return context.WithValue(ctx, clientKey{}, c)
|
||||
}
|
||||
|
||||
// ClientFromContext returns the current client from the given context.
|
||||
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
|
||||
c, ok = ctx.Value(clientKey{}).(Client)
|
||||
return
|
||||
}
|
||||
|
||||
// MustClientFromContext returns the current client from the given context. It will
|
||||
// return a new instance of the client if it does not exist.
|
||||
func MustClientFromContext(ctx context.Context) Client {
|
||||
c, ok := ClientFromContext(ctx)
|
||||
if !ok {
|
||||
return NewClient()
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type client struct {
|
||||
http *http.Client
|
||||
dialer *net.Dialer
|
||||
}
|
||||
|
||||
// NewClient returns an implementation of Client for verifying ACME challenges.
|
||||
func NewClient() Client {
|
||||
return &client{
|
||||
http: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Get(url string) (*http.Response, error) {
|
||||
return c.http.Get(url)
|
||||
}
|
||||
|
||||
func (c *client) LookupTxt(name string) ([]string, error) {
|
||||
return net.LookupTXT(name)
|
||||
}
|
||||
|
||||
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(c.dialer, network, addr, config)
|
||||
}
|
|
@ -9,15 +9,6 @@ import (
|
|||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
AreSANsAllowed(ctx context.Context, sans []string) error
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
|
||||
|
@ -28,6 +19,52 @@ func (c *Clock) Now() time.Time {
|
|||
|
||||
var clock Clock
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
AreSANsAllowed(ctx context.Context, sans []string) error
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// NewContext adds the given acme components to the context.
|
||||
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
|
||||
ctx = NewDatabaseContext(ctx, db)
|
||||
ctx = NewClientContext(ctx, client)
|
||||
ctx = NewLinkerContext(ctx, linker)
|
||||
// Prerequisite checker is optional.
|
||||
if fn != nil {
|
||||
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// PrerequisitesChecker is a function that checks if all prerequisites for
|
||||
// serving ACME are met by the CA configuration.
|
||||
type PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
|
||||
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
|
||||
// always true.
|
||||
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type prerequisitesKey struct{}
|
||||
|
||||
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
|
||||
// context.
|
||||
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
|
||||
return context.WithValue(ctx, prerequisitesKey{}, fn)
|
||||
}
|
||||
|
||||
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
|
||||
// context.
|
||||
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
|
||||
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
|
||||
return fn, ok && fn != nil
|
||||
}
|
||||
|
||||
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
||||
// only those methods required by the ACME api/authority.
|
||||
type Provisioner interface {
|
||||
|
@ -40,6 +77,29 @@ type Provisioner interface {
|
|||
GetOptions() *provisioner.Options
|
||||
}
|
||||
|
||||
type provisionerKey struct{}
|
||||
|
||||
// NewProvisionerContext adds the given provisioner to the context.
|
||||
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
|
||||
return context.WithValue(ctx, provisionerKey{}, v)
|
||||
}
|
||||
|
||||
// ProvisionerFromContext returns the current provisioner from the given context.
|
||||
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
|
||||
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current provisioner from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustProvisionerFromContext(ctx context.Context) Provisioner {
|
||||
if v, ok := ProvisionerFromContext(ctx); !ok {
|
||||
panic("acme provisioner is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// MockProvisioner for testing
|
||||
type MockProvisioner struct {
|
||||
Mret1 interface{}
|
||||
|
|
23
acme/db.go
23
acme/db.go
|
@ -49,6 +49,29 @@ type DB interface {
|
|||
UpdateOrder(ctx context.Context, o *Order) error
|
||||
}
|
||||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewDatabaseContext adds the given acme database to the context.
|
||||
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// DatabaseFromContext returns the current acme database from the given context.
|
||||
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(DB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustDatabaseFromContext returns the current database from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustDatabaseFromContext(ctx context.Context) DB {
|
||||
if db, ok := DatabaseFromContext(ctx); !ok {
|
||||
panic("acme database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
// a mock in tests.
|
||||
type MockDB struct {
|
||||
|
|
|
@ -1,100 +1,19 @@
|
|||
package api
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
|
||||
|
||||
LinkOrder(ctx context.Context, o *acme.Order)
|
||||
LinkAccount(ctx context.Context, o *acme.Account)
|
||||
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *acme.Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var (
|
||||
provName string
|
||||
baseURL = baseURLFromContext(ctx)
|
||||
u = url.URL{}
|
||||
)
|
||||
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
|
||||
provName = p.GetName()
|
||||
}
|
||||
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
||||
if baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
|
||||
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
|
||||
|
||||
// If no Scheme is set, then default to https.
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
|
||||
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + u.Path
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkType captures the link type.
|
||||
type LinkType int
|
||||
|
||||
|
@ -160,8 +79,155 @@ func (l LinkType) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
Middleware(http.Handler) http.Handler
|
||||
LinkOrder(ctx context.Context, o *Order)
|
||||
LinkAccount(ctx context.Context, o *Account)
|
||||
LinkChallenge(ctx context.Context, o *Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
type linkerKey struct{}
|
||||
|
||||
// NewLinkerContext adds the given linker to the context.
|
||||
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
|
||||
return context.WithValue(ctx, linkerKey{}, v)
|
||||
}
|
||||
|
||||
// LinkerFromContext returns the current linker from the given context.
|
||||
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
|
||||
v, ok = ctx.Value(linkerKey{}).(Linker)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current linker from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustLinkerFromContext(ctx context.Context) Linker {
|
||||
if v, ok := LinkerFromContext(ctx); !ok {
|
||||
panic("acme linker is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
type baseURLKey struct{}
|
||||
|
||||
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
|
||||
var u *url.URL
|
||||
if r.Host != "" {
|
||||
u = &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
return context.WithValue(ctx, baseURLKey{}, u)
|
||||
}
|
||||
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
|
||||
return u
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
// Middleware gets the provisioner and current url from the request and sets
|
||||
// them in the context so we can use the linker to create ACME links.
|
||||
func (l *linker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add base url to the context.
|
||||
ctx := newBaseURLContext(r.Context(), r)
|
||||
|
||||
// Add provisioner to the context.
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
|
||||
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit.
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var name string
|
||||
if p, ok := ProvisionerFromContext(ctx); ok {
|
||||
name = p.GetName()
|
||||
}
|
||||
|
||||
var u url.URL
|
||||
if baseURL := baseURLFromContext(ctx); baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkOrder sets the ACME links required by an ACME order.
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *Order) {
|
||||
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
||||
for i, azID := range o.AuthorizationIDs {
|
||||
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
|
||||
|
@ -173,17 +239,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
|||
}
|
||||
|
||||
// LinkAccount sets the ACME links required by an ACME account.
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
|
||||
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
|
||||
}
|
||||
|
||||
// LinkChallenge sets the ACME links required by an ACME challenge.
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
|
||||
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
|
||||
}
|
||||
|
||||
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
|
||||
for _, ch := range az.Challenges {
|
||||
l.LinkChallenge(ctx, ch, az.ID)
|
||||
}
|
|
@ -1,21 +1,38 @@
|
|||
package api
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
linker := NewLinker(dns, prefix)
|
||||
func mockProvisioner(t *testing.T) Provisioner {
|
||||
t.Helper()
|
||||
var defaultDisableRenewal = false
|
||||
|
||||
getPath := linker.GetUnescapedPathSuffix
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-<test>provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestGetUnescapedPathSuffix(t *testing.T) {
|
||||
getPath := GetUnescapedPathSuffix
|
||||
|
||||
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
|
||||
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
|
||||
|
@ -32,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLinker_DNS(t *testing.T) {
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
type test struct {
|
||||
name string
|
||||
dns string
|
||||
|
@ -117,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) {
|
|||
linker := NewLinker(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
// No provisioner and no BaseURL from request
|
||||
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
|
||||
// Provisioner: yes, BaseURL: no
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
|
||||
// Provisioner: no, BaseURL: yes
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
|
@ -163,37 +180,37 @@ func TestLinker_GetLink(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkOrder(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
oid := "orderID"
|
||||
certID := "certID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
o *acme.Order
|
||||
validate func(o *acme.Order)
|
||||
o *Order
|
||||
validate func(o *Order)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"no-authz-and-no-cert": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{})
|
||||
assert.Equals(t, o.CertificateURL, "")
|
||||
},
|
||||
},
|
||||
"one-authz-and-cert": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo"},
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
|
@ -202,12 +219,12 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
},
|
||||
},
|
||||
"many-authz": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
|
@ -228,24 +245,24 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkAccount(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
accID := "accountID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
a *acme.Account
|
||||
validate func(o *acme.Account)
|
||||
a *Account
|
||||
validate func(o *Account)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
a: &acme.Account{
|
||||
a: &Account{
|
||||
ID: accID,
|
||||
},
|
||||
validate: func(a *acme.Account) {
|
||||
validate: func(a *Account) {
|
||||
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
||||
},
|
||||
},
|
||||
|
@ -260,25 +277,25 @@ func TestLinker_LinkAccount(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkChallenge(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
chID := "chID"
|
||||
azID := "azID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
ch *acme.Challenge
|
||||
validate func(o *acme.Challenge)
|
||||
ch *Challenge
|
||||
validate func(o *Challenge)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
ch: &acme.Challenge{
|
||||
ch: &Challenge{
|
||||
ID: chID,
|
||||
},
|
||||
validate: func(ch *acme.Challenge) {
|
||||
validate: func(ch *Challenge) {
|
||||
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
||||
},
|
||||
},
|
||||
|
@ -293,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkAuthorization(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
chID0 := "chID-0"
|
||||
chID1 := "chID-1"
|
||||
|
@ -305,20 +322,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
|||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
az *acme.Authorization
|
||||
validate func(o *acme.Authorization)
|
||||
az *Authorization
|
||||
validate func(o *Authorization)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
az: &acme.Authorization{
|
||||
az: &Authorization{
|
||||
ID: azID,
|
||||
Challenges: []*acme.Challenge{
|
||||
Challenges: []*Challenge{
|
||||
{ID: chID0},
|
||||
{ID: chID1},
|
||||
{ID: chID2},
|
||||
},
|
||||
},
|
||||
validate: func(az *acme.Authorization) {
|
||||
validate: func(az *Authorization) {
|
||||
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
||||
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
||||
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
||||
|
@ -335,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
107
api/api.go
107
api/api.go
|
@ -35,7 +35,6 @@ type Authority interface {
|
|||
SSHAuthority
|
||||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
GetTLSOptions() *config.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
|
@ -52,6 +51,11 @@ type Authority interface {
|
|||
Version() authority.Version
|
||||
}
|
||||
|
||||
// mustAuthority will be replaced on unit tests.
|
||||
var mustAuthority = func(ctx context.Context) Authority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// TimeDuration is an alias of provisioner.TimeDuration
|
||||
type TimeDuration = provisioner.TimeDuration
|
||||
|
||||
|
@ -243,48 +247,53 @@ type caHandler struct {
|
|||
Authority Authority
|
||||
}
|
||||
|
||||
// New creates a new RouterHandler with the CA endpoints.
|
||||
func New(auth Authority) RouterHandler {
|
||||
return &caHandler{
|
||||
Authority: auth,
|
||||
}
|
||||
// Route configures the http request router.
|
||||
func (h *caHandler) Route(r Router) {
|
||||
Route(r)
|
||||
}
|
||||
|
||||
func (h *caHandler) Route(r Router) {
|
||||
r.MethodFunc("GET", "/version", h.Version)
|
||||
r.MethodFunc("GET", "/health", h.Health)
|
||||
r.MethodFunc("GET", "/root/{sha}", h.Root)
|
||||
r.MethodFunc("POST", "/sign", h.Sign)
|
||||
r.MethodFunc("POST", "/renew", h.Renew)
|
||||
r.MethodFunc("POST", "/rekey", h.Rekey)
|
||||
r.MethodFunc("POST", "/revoke", h.Revoke)
|
||||
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", h.Roots)
|
||||
r.MethodFunc("GET", "/roots.pem", h.RootsPEM)
|
||||
r.MethodFunc("GET", "/federation", h.Federation)
|
||||
// New creates a new RouterHandler with the CA endpoints.
|
||||
//
|
||||
// Deprecated: Use api.Route(r Router)
|
||||
func New(auth Authority) RouterHandler {
|
||||
return &caHandler{}
|
||||
}
|
||||
|
||||
func Route(r Router) {
|
||||
r.MethodFunc("GET", "/version", Version)
|
||||
r.MethodFunc("GET", "/health", Health)
|
||||
r.MethodFunc("GET", "/root/{sha}", Root)
|
||||
r.MethodFunc("POST", "/sign", Sign)
|
||||
r.MethodFunc("POST", "/renew", Renew)
|
||||
r.MethodFunc("POST", "/rekey", Rekey)
|
||||
r.MethodFunc("POST", "/revoke", Revoke)
|
||||
r.MethodFunc("GET", "/provisioners", Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", Roots)
|
||||
r.MethodFunc("GET", "/roots.pem", RootsPEM)
|
||||
r.MethodFunc("GET", "/federation", Federation)
|
||||
// SSH CA
|
||||
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
|
||||
r.MethodFunc("POST", "/ssh/renew", h.SSHRenew)
|
||||
r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke)
|
||||
r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey)
|
||||
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots)
|
||||
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation)
|
||||
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
|
||||
r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts)
|
||||
r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion)
|
||||
r.MethodFunc("POST", "/ssh/sign", SSHSign)
|
||||
r.MethodFunc("POST", "/ssh/renew", SSHRenew)
|
||||
r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
|
||||
r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
|
||||
r.MethodFunc("GET", "/ssh/roots", SSHRoots)
|
||||
r.MethodFunc("GET", "/ssh/federation", SSHFederation)
|
||||
r.MethodFunc("POST", "/ssh/config", SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
|
||||
r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
|
||||
r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
|
||||
|
||||
// For compatibility with old code:
|
||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||
r.MethodFunc("POST", "/sign-ssh", h.SSHSign)
|
||||
r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts)
|
||||
r.MethodFunc("POST", "/re-sign", Renew)
|
||||
r.MethodFunc("POST", "/sign-ssh", SSHSign)
|
||||
r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
|
||||
}
|
||||
|
||||
// Version is an HTTP handler that returns the version of the server.
|
||||
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := h.Authority.Version()
|
||||
func Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := mustAuthority(r.Context()).Version()
|
||||
render.JSON(w, VersionResponse{
|
||||
Version: v.Version,
|
||||
RequireClientAuthentication: v.RequireClientAuthentication,
|
||||
|
@ -292,17 +301,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Health is an HTTP handler that returns the status of the server.
|
||||
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||
func Health(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSON(w, HealthResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
||||
// certificate for the given SHA256.
|
||||
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
||||
func Root(w http.ResponseWriter, r *http.Request) {
|
||||
sha := chi.URLParam(r, "sha")
|
||||
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
||||
// Load root certificate with the
|
||||
cert, err := h.Authority.Root(sum)
|
||||
cert, err := mustAuthority(r.Context()).Root(sum)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
return
|
||||
|
@ -320,18 +329,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
|||
}
|
||||
|
||||
// Provisioners returns the list of provisioners configured in the authority.
|
||||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
func Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
|
@ -339,19 +349,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// ProvisionerKey returns the encrypted key of a provisioner by it's key id.
|
||||
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
kid := chi.URLParam(r, "kid")
|
||||
key, err := h.Authority.GetEncryptedKey(kid)
|
||||
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
|
||||
if err != nil {
|
||||
render.Error(w, errs.NotFoundErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionerKeyResponse{key})
|
||||
}
|
||||
|
||||
// Roots returns all the root certificates for the CA.
|
||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||
return
|
||||
|
@ -368,8 +379,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// RootsPEM returns all the root certificates for the CA in PEM format.
|
||||
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -391,8 +402,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Federation returns all the public certificates in the federation.
|
||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := h.Authority.GetFederation()
|
||||
func Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := mustAuthority(r.Context()).GetFederation()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||
return
|
||||
|
|
|
@ -171,10 +171,21 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
|
|||
return csr
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a Authority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) Authority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
type mockAuthority struct {
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
||||
authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
getTLSOptions func() *authority.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
|
@ -203,12 +214,8 @@ type mockAuthority struct {
|
|||
|
||||
// TODO: remove once Authorize is deprecated.
|
||||
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return m.AuthorizeSign(ott)
|
||||
}
|
||||
|
||||
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
if m.authorizeSign != nil {
|
||||
return m.authorizeSign(ott)
|
||||
if m.authorize != nil {
|
||||
return m.authorize(ctx, ott)
|
||||
}
|
||||
return m.ret1.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
@ -789,11 +796,10 @@ func Test_caHandler_Route(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Health(t *testing.T) {
|
||||
func Test_Health(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h := New(&mockAuthority{}).(*caHandler)
|
||||
h.Health(w, req)
|
||||
Health(w, req)
|
||||
|
||||
res := w.Result()
|
||||
if res.StatusCode != 200 {
|
||||
|
@ -811,7 +817,7 @@ func Test_caHandler_Health(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Root(t *testing.T) {
|
||||
func Test_Root(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
root *x509.Certificate
|
||||
|
@ -832,9 +838,9 @@ func Test_caHandler_Root(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
|
||||
mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
|
||||
w := httptest.NewRecorder()
|
||||
h.Root(w, req)
|
||||
Root(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -855,7 +861,7 @@ func Test_caHandler_Root(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Sign(t *testing.T) {
|
||||
func Test_Sign(t *testing.T) {
|
||||
csr := parseCertificateRequest(csrPEM)
|
||||
valid, err := json.Marshal(SignRequest{
|
||||
CsrPEM: CertificateRequest{csr},
|
||||
|
@ -896,18 +902,18 @@ func Test_caHandler_Sign(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return tt.certAttrOpts, tt.autherr
|
||||
},
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
|
||||
w := httptest.NewRecorder()
|
||||
h.Sign(logging.NewResponseLogger(w), req)
|
||||
Sign(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -928,7 +934,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Renew(t *testing.T) {
|
||||
func Test_Renew(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1018,7 +1024,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
||||
|
@ -1039,12 +1045,12 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
||||
req.TLS = tt.tls
|
||||
req.Header = tt.header
|
||||
w := httptest.NewRecorder()
|
||||
h.Renew(logging.NewResponseLogger(w), req)
|
||||
Renew(logging.NewResponseLogger(w), req)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
@ -1073,7 +1079,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Rekey(t *testing.T) {
|
||||
func Test_Rekey(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1104,16 +1110,16 @@ func Test_caHandler_Rekey(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
h.Rekey(logging.NewResponseLogger(w), req)
|
||||
Rekey(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1134,7 +1140,7 @@ func Test_caHandler_Rekey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Provisioners(t *testing.T) {
|
||||
func Test_Provisioners(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
}
|
||||
|
@ -1200,10 +1206,8 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &caHandler{
|
||||
Authority: tt.fields.Authority,
|
||||
}
|
||||
h.Provisioners(tt.args.w, tt.args.r)
|
||||
mockMustAuthority(t, tt.fields.Authority)
|
||||
Provisioners(tt.args.w, tt.args.r)
|
||||
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
res := rec.Result()
|
||||
|
@ -1238,7 +1242,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_ProvisionerKey(t *testing.T) {
|
||||
func Test_ProvisionerKey(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
}
|
||||
|
@ -1270,10 +1274,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &caHandler{
|
||||
Authority: tt.fields.Authority,
|
||||
}
|
||||
h.ProvisionerKey(tt.args.w, tt.args.r)
|
||||
mockMustAuthority(t, tt.fields.Authority)
|
||||
ProvisionerKey(tt.args.w, tt.args.r)
|
||||
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
res := rec.Result()
|
||||
|
@ -1298,7 +1300,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Roots(t *testing.T) {
|
||||
func Test_Roots(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1319,11 +1321,11 @@ func Test_caHandler_Roots(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
h.Roots(w, req)
|
||||
Roots(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1360,10 +1362,10 @@ func Test_caHandler_RootsPEM(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler)
|
||||
mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
|
||||
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.RootsPEM(w, req)
|
||||
RootsPEM(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1384,7 +1386,7 @@ func Test_caHandler_RootsPEM(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Federation(t *testing.T) {
|
||||
func Test_Federation(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1405,11 +1407,11 @@ func Test_caHandler_Federation(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
h.Federation(w, req)
|
||||
Federation(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
|
|
@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error {
|
|||
}
|
||||
|
||||
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
||||
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing client certificate"))
|
||||
return
|
||||
|
@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||
return
|
||||
|
@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
|||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
|
14
api/renew.go
14
api/renew.go
|
@ -16,14 +16,15 @@ const (
|
|||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||
cert, err := h.getPeerCertificate(r)
|
||||
func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
cert, err := getPeerCertificate(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(cert)
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Renew(cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
|
@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
|||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
return r.TLS.PeerCertificates[0], nil
|
||||
}
|
||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
|
||||
ctx := r.Context()
|
||||
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
}
|
||||
}
|
||||
return nil, errs.BadRequest("missing client certificate")
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) {
|
|||
// NOTE: currently only Passive revocation is supported.
|
||||
//
|
||||
// TODO: Add CRL and OCSP support.
|
||||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
PassiveOnly: body.Passive,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
if len(body.OTT) > 0 {
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
opts.MTLS = true
|
||||
}
|
||||
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusOK,
|
||||
auth: &mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
|
@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
statusCode: http.StatusOK,
|
||||
tls: cs,
|
||||
auth: &mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
|
||||
|
@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusInternalServerError,
|
||||
auth: &mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
|
@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusForbidden,
|
||||
auth: &mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
|
@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
for name, _tc := range tests {
|
||||
tc := _tc(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*caHandler)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
|
||||
if tc.tls != nil {
|
||||
req.TLS = tc.tls
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
h.Revoke(logging.NewResponseLogger(w), req)
|
||||
Revoke(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
12
api/sign.go
12
api/sign.go
|
@ -49,7 +49,7 @@ type SignResponse struct {
|
|||
// Sign is an HTTP handler that reads a certificate request and an
|
||||
// one-time-token (ott) from the body and creates a new certificate with the
|
||||
// information in the certificate request.
|
||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -68,13 +68,17 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
|||
TemplateData: body.TemplateData,
|
||||
}
|
||||
|
||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||
ctx := r.Context()
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||
return
|
||||
|
@ -89,6 +93,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
|||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
|
44
api/ssh.go
44
api/ssh.go
|
@ -250,7 +250,7 @@ type SSHBastionResponse struct {
|
|||
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -289,13 +289,15 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
|
@ -303,7 +305,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var addUserCertificate *SSHCertificate
|
||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
|
@ -316,7 +318,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
if cr := body.IdentityCSR.CertificateRequest; cr != nil {
|
||||
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -328,7 +330,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
|
||||
})
|
||||
|
||||
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||
return
|
||||
|
@ -345,8 +347,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
|
||||
// certificates.
|
||||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHRoots(r.Context())
|
||||
func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -370,8 +373,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||
// for user and host certificates.
|
||||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHFederation(r.Context())
|
||||
func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -395,7 +399,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
||||
// and servers.
|
||||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -406,7 +410,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
|
||||
ctx := r.Context()
|
||||
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -427,7 +432,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHCheckPrincipalRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -438,7 +443,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
|
||||
ctx := r.Context()
|
||||
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -449,13 +455,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
var cert *x509.Certificate
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
cert = r.TLS.PeerCertificates[0]
|
||||
}
|
||||
|
||||
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
|
||||
ctx := r.Context()
|
||||
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -466,7 +473,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHBastion provides returns the bastion configured if any.
|
||||
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -477,7 +484,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
|
||||
ctx := r.Context()
|
||||
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
|
|
@ -39,7 +39,7 @@ type SSHRekeyResponse struct {
|
|||
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -60,7 +60,9 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -71,7 +73,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||
return
|
||||
|
@ -81,7 +83,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
|
|
|
@ -37,7 +37,7 @@ type SSHRenewResponse struct {
|
|||
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -52,7 +52,9 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
_, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
_, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -63,7 +65,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
||||
newCert, err := a.RenewSSH(ctx, oldCert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||
return
|
||||
|
@ -73,7 +75,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
|
@ -86,7 +88,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
|
||||
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||
func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -106,7 +108,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte
|
|||
cert.NotAfter = notAfter
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(cert)
|
||||
certChain, err := mustAuthority(r.Context()).Renew(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
|||
// Revoke supports handful of different methods that revoke a Certificate.
|
||||
//
|
||||
// NOTE: currently only Passive revocation is supported.
|
||||
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHSign(t *testing.T) {
|
||||
func Test_SSHSign(t *testing.T) {
|
||||
user, err := getSignedUserCertificate()
|
||||
assert.FatalError(t, err)
|
||||
host, err := getSignedHostCertificate()
|
||||
|
@ -315,8 +315,8 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return []provisioner.SignOption{}, tt.authErr
|
||||
},
|
||||
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
|
@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
return tt.tlsSignCerts, tt.tlsSignErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHSign(logging.NewResponseLogger(w), req)
|
||||
SSHSign(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHRoots(t *testing.T) {
|
||||
func Test_SSHRoots(t *testing.T) {
|
||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
|
@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHRoots(logging.NewResponseLogger(w), req)
|
||||
SSHRoots(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHFederation(t *testing.T) {
|
||||
func Test_SSHFederation(t *testing.T) {
|
||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
|
@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHFederation(logging.NewResponseLogger(w), req)
|
||||
SSHFederation(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHConfig(t *testing.T) {
|
||||
func Test_SSHConfig(t *testing.T) {
|
||||
userOutput := []templates.Output{
|
||||
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
|
||||
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
||||
|
@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
return tt.output, tt.err
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHConfig(logging.NewResponseLogger(w), req)
|
||||
SSHConfig(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||
func Test_SSHCheckHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req string
|
||||
|
@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
||||
return tt.exists, tt.err
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||
SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||
func Test_SSHGetHosts(t *testing.T) {
|
||||
hosts := []authority.Host{
|
||||
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
|
||||
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
|
||||
|
@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
||||
return tt.hosts, tt.err
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||
SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHBastion(t *testing.T) {
|
||||
func Test_SSHBastion(t *testing.T) {
|
||||
bastion := &authority.Bastion{
|
||||
Hostname: "bastion.local",
|
||||
}
|
||||
|
@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
return tt.bastion, tt.bastionErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHBastion(logging.NewResponseLogger(w), req)
|
||||
SSHBastion(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
|
|
@ -33,7 +33,7 @@ type GetExternalAccountKeysResponse struct {
|
|||
|
||||
// requireEABEnabled is a middleware that ensures ACME EAB is enabled
|
||||
// before serving requests that act on ACME EAB credentials.
|
||||
func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
@ -53,32 +53,33 @@ func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
type acmeAdminResponderInterface interface {
|
||||
// ACMEAdminResponder is responsible for writing ACME admin responses
|
||||
type ACMEAdminResponder interface {
|
||||
GetExternalAccountKeys(w http.ResponseWriter, r *http.Request)
|
||||
CreateExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
||||
DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// ACMEAdminResponder is responsible for writing ACME admin responses
|
||||
type ACMEAdminResponder struct{}
|
||||
// acmeAdminResponder implements ACMEAdminResponder.
|
||||
type acmeAdminResponder struct{}
|
||||
|
||||
// NewACMEAdminResponder returns a new ACMEAdminResponder
|
||||
func NewACMEAdminResponder() *ACMEAdminResponder {
|
||||
return &ACMEAdminResponder{}
|
||||
func NewACMEAdminResponder() ACMEAdminResponder {
|
||||
return &acmeAdminResponder{}
|
||||
}
|
||||
|
||||
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
|
||||
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
|
||||
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
|
||||
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
|
|
|
@ -33,6 +33,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
|||
return protojson.Unmarshal(data, m)
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a adminAuthority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) adminAuthority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_requireEABEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
|
@ -117,12 +128,9 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.requireEABEnabled(tc.next)(w, req)
|
||||
requireEABEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
|
@ -85,10 +85,10 @@ type DeleteResponse struct {
|
|||
}
|
||||
|
||||
// GetAdmin returns the requested admin, or an error.
|
||||
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, ok := h.auth.LoadAdminByID(id)
|
||||
adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
|
||||
if !ok {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found", id))
|
||||
|
@ -98,7 +98,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// GetAdmins returns a segment of admins associated with the authority.
|
||||
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
func GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
|
@ -106,7 +106,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
|
||||
admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||
return
|
||||
|
@ -118,7 +118,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// CreateAdmin creates a new admin.
|
||||
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body CreateAdminRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
|
@ -130,7 +130,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
|
||||
auth := mustAuthority(r.Context())
|
||||
p, err := auth.LoadProvisionerByName(body.Provisioner)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||
return
|
||||
|
@ -141,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
Type: body.Type,
|
||||
}
|
||||
// Store to authority collection.
|
||||
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
|
||||
return
|
||||
}
|
||||
|
@ -150,10 +151,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// DeleteAdmin deletes admin.
|
||||
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
|
||||
if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||
return
|
||||
}
|
||||
|
@ -162,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// UpdateAdmin updates an existing admin.
|
||||
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body UpdateAdminRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
|
@ -175,8 +176,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
auth := mustAuthority(r.Context())
|
||||
adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||
return
|
||||
|
|
|
@ -352,14 +352,11 @@ func TestHandler_GetAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAdmin(w, req)
|
||||
GetAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -491,13 +488,10 @@ func TestHandler_GetAdmins(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAdmins(w, req)
|
||||
GetAdmins(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -675,13 +669,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateAdmin(w, req)
|
||||
CreateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -767,13 +759,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteAdmin(w, req)
|
||||
DeleteAdmin(w, req)
|
||||
res := w.Result()
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
|
@ -912,13 +902,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateAdmin(w, req)
|
||||
UpdateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
|
@ -1,50 +1,58 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
// Handler is the Admin API request handler.
|
||||
type Handler struct {
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
acmeDB acme.DB
|
||||
acmeResponder acmeAdminResponderInterface
|
||||
policyResponder policyAdminResponderInterface
|
||||
acmeResponder ACMEAdminResponder
|
||||
policyResponder PolicyAdminResponder
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
//
|
||||
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
Route(r, h.acmeResponder, h.policyResponder)
|
||||
}
|
||||
|
||||
// NewHandler returns a new Authority Config Handler.
|
||||
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler {
|
||||
//
|
||||
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
|
||||
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler {
|
||||
return &Handler{
|
||||
auth: auth,
|
||||
adminDB: adminDB,
|
||||
acmeDB: acmeDB,
|
||||
acmeResponder: acmeResponder,
|
||||
policyResponder: policyResponder,
|
||||
}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
var mustAuthority = func(ctx context.Context) adminAuthority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) {
|
||||
authnz := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
|
||||
return extractAuthorizeTokenAdmin(requireAPIEnabled(next))
|
||||
}
|
||||
|
||||
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return h.checkAction(next, true)
|
||||
return checkAction(next, true)
|
||||
}
|
||||
|
||||
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return h.checkAction(next, false)
|
||||
return checkAction(next, false)
|
||||
}
|
||||
|
||||
acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(h.loadProvisionerByName(h.requireEABEnabled(next)))
|
||||
return authnz(loadProvisionerByName(requireEABEnabled(next)))
|
||||
}
|
||||
|
||||
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
|
@ -52,53 +60,58 @@ func (h *Handler) Route(r api.Router) {
|
|||
}
|
||||
|
||||
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(disabledInStandalone(h.loadProvisionerByName(next)))
|
||||
return authnz(disabledInStandalone(loadProvisionerByName(next)))
|
||||
}
|
||||
|
||||
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.loadExternalAccountKey(next)))))
|
||||
return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next)))))
|
||||
}
|
||||
|
||||
// Provisioners
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
|
||||
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners))
|
||||
r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner))
|
||||
|
||||
// Admins
|
||||
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin))
|
||||
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins))
|
||||
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin))
|
||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin))
|
||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
|
||||
r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin))
|
||||
r.MethodFunc("GET", "/admins", authnz(GetAdmins))
|
||||
r.MethodFunc("POST", "/admins", authnz(CreateAdmin))
|
||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin))
|
||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin))
|
||||
|
||||
// ACME responder
|
||||
if acmeResponder != nil {
|
||||
// ACME External Account Binding Keys
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys))
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys))
|
||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.CreateExternalAccountKey))
|
||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(h.acmeResponder.DeleteExternalAccountKey))
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
|
||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey))
|
||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey))
|
||||
}
|
||||
|
||||
// Policy responder
|
||||
if policyResponder != nil {
|
||||
// Policy - Authority
|
||||
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(h.policyResponder.GetAuthorityPolicy))
|
||||
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(h.policyResponder.CreateAuthorityPolicy))
|
||||
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(h.policyResponder.UpdateAuthorityPolicy))
|
||||
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(h.policyResponder.DeleteAuthorityPolicy))
|
||||
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy))
|
||||
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy))
|
||||
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy))
|
||||
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy))
|
||||
|
||||
// Policy - Provisioner
|
||||
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.GetProvisionerPolicy))
|
||||
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.CreateProvisionerPolicy))
|
||||
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.UpdateProvisionerPolicy))
|
||||
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.DeleteProvisionerPolicy))
|
||||
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy))
|
||||
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy))
|
||||
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy))
|
||||
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy))
|
||||
|
||||
// Policy - ACME Account
|
||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy))
|
||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy))
|
||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy))
|
||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy))
|
||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy))
|
||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy))
|
||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy))
|
||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy))
|
||||
|
||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
|
||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
|
||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
|
||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
|
||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
|
||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
|
||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
|
||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,11 +17,10 @@ import (
|
|||
|
||||
// requireAPIEnabled is a middleware that ensures the Administration API
|
||||
// is enabled before servicing requests.
|
||||
func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.auth.IsAdminAPIEnabled() {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"administration API not enabled"))
|
||||
if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
|
@ -29,7 +28,7 @@ func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
|||
}
|
||||
|
||||
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
|
||||
func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||
func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
tok := r.Header.Get("Authorization")
|
||||
|
@ -39,36 +38,39 @@ func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.Handler
|
|||
return
|
||||
}
|
||||
|
||||
adm, err := h.auth.AuthorizeAdminToken(r, tok)
|
||||
ctx := r.Context()
|
||||
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := linkedca.NewContextWithAdmin(r.Context(), adm)
|
||||
ctx = linkedca.NewContextWithAdmin(ctx, adm)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// loadProvisionerByName is a middleware that searches for a provisioner
|
||||
// by name and stores it in the context.
|
||||
func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
|
||||
func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
name := chi.URLParam(r, "provisionerName")
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
|
||||
ctx := r.Context()
|
||||
auth := mustAuthority(ctx)
|
||||
adminDB := admin.MustFromContext(ctx)
|
||||
name := chi.URLParam(r, "provisionerName")
|
||||
|
||||
// TODO(hs): distinguish 404 vs. 500
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
prov, err := adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
|
||||
return
|
||||
|
@ -80,9 +82,8 @@ func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc
|
|||
}
|
||||
|
||||
// checkAction checks if an action is supported in standalone or not
|
||||
func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
|
||||
func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// actions allowed in standalone mode are always supported
|
||||
if supportedInStandalone {
|
||||
next(w, r)
|
||||
|
@ -91,7 +92,7 @@ func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool)
|
|||
|
||||
// when an action is not supported in standalone mode and when
|
||||
// using a nosql.DB backend, actions are not supported
|
||||
if _, ok := h.adminDB.(*nosql.DB); ok {
|
||||
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"operation not supported in standalone mode"))
|
||||
return
|
||||
|
@ -104,10 +105,11 @@ func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool)
|
|||
|
||||
// loadExternalAccountKey is a middleware that searches for an ACME
|
||||
// External Account Key by reference or keyID and stores it in the context.
|
||||
func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
|
||||
func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
reference := chi.URLParam(r, "reference")
|
||||
keyID := chi.URLParam(r, "keyID")
|
||||
|
@ -118,9 +120,9 @@ func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc
|
|||
)
|
||||
|
||||
if keyID != "" {
|
||||
eak, err = h.acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
|
||||
eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
|
||||
} else {
|
||||
eak, err = h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
|
||||
eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -71,13 +71,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.requireAPIEnabled(tc.next)(w, req)
|
||||
requireAPIEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -196,13 +194,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -251,6 +246,7 @@ func TestHandler_loadProvisionerByName(t *testing.T) {
|
|||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: err,
|
||||
}
|
||||
|
@ -326,16 +322,13 @@ func TestHandler_loadProvisionerByName(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
h.loadProvisionerByName(tc.next)(w, req)
|
||||
loadProvisionerByName(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -405,14 +398,10 @@ func TestHandler_checkAction(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
ctx := admin.NewContext(context.Background(), tc.adminDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.checkAction(tc.next, tc.supportedInStandalone)(w, req)
|
||||
checkAction(tc.next, tc.supportedInStandalone)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -653,14 +642,11 @@ func TestHandler_loadExternalAccountKey(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
acmeDB: tc.acmeDB,
|
||||
}
|
||||
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.loadExternalAccountKey(tc.next)(w, req)
|
||||
loadExternalAccountKey(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
|
@ -14,7 +15,9 @@ import (
|
|||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
type policyAdminResponderInterface interface {
|
||||
// PolicyAdminResponder is the interface responsible for writing ACME admin
|
||||
// responses.
|
||||
type PolicyAdminResponder interface {
|
||||
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
|
@ -29,39 +32,24 @@ type policyAdminResponderInterface interface {
|
|||
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// PolicyAdminResponder is responsible for writing ACME admin responses
|
||||
type PolicyAdminResponder struct {
|
||||
auth adminAuthority
|
||||
adminDB admin.DB
|
||||
acmeDB acme.DB
|
||||
isLinkedCA bool
|
||||
}
|
||||
// policyAdminResponder implements PolicyAdminResponder.
|
||||
type policyAdminResponder struct{}
|
||||
|
||||
// NewACMEAdminResponder returns a new ACMEAdminResponder
|
||||
func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) *PolicyAdminResponder {
|
||||
|
||||
var isLinkedCA bool
|
||||
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok {
|
||||
isLinkedCA = a.IsLinkedCA()
|
||||
}
|
||||
|
||||
return &PolicyAdminResponder{
|
||||
auth: auth,
|
||||
adminDB: adminDB,
|
||||
acmeDB: acmeDB,
|
||||
isLinkedCA: isLinkedCA,
|
||||
}
|
||||
// NewACMEAdminResponder returns a new PolicyAdminResponder.
|
||||
func NewPolicyAdminResponder() PolicyAdminResponder {
|
||||
return &policyAdminResponder{}
|
||||
}
|
||||
|
||||
// GetAuthorityPolicy handles the GET /admin/authority/policy request
|
||||
func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
authorityPolicy, err := par.auth.GetAuthorityPolicy(r.Context())
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
|
||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
|
@ -76,15 +64,15 @@ func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht
|
|||
}
|
||||
|
||||
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
|
||||
func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
|
@ -113,7 +101,7 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
|
|||
adm := linkedca.MustAdminFromContext(ctx)
|
||||
|
||||
var createdPolicy *linkedca.Policy
|
||||
if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
|
||||
return
|
||||
|
@ -127,15 +115,15 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
|
|||
}
|
||||
|
||||
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
|
||||
func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
|
@ -163,7 +151,7 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
|
|||
adm := linkedca.MustAdminFromContext(ctx)
|
||||
|
||||
var updatedPolicy *linkedca.Policy
|
||||
if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
|
||||
return
|
||||
|
@ -177,15 +165,15 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
|
|||
}
|
||||
|
||||
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
|
||||
func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
|
@ -197,7 +185,7 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r
|
|||
return
|
||||
}
|
||||
|
||||
if err := par.auth.RemoveAuthorityPolicy(ctx); err != nil {
|
||||
if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
|
||||
return
|
||||
}
|
||||
|
@ -206,15 +194,14 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r
|
|||
}
|
||||
|
||||
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
|
||||
func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(r.Context())
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
|
@ -225,16 +212,14 @@ func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *
|
|||
}
|
||||
|
||||
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
|
||||
func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
|
||||
|
@ -256,8 +241,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
|
|||
}
|
||||
|
||||
prov.Policy = newPolicy
|
||||
|
||||
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
|
||||
return
|
||||
|
@ -271,16 +256,14 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
|
|||
}
|
||||
|
||||
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
|
||||
func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
|
@ -301,7 +284,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
|
|||
}
|
||||
|
||||
prov.Policy = newPolicy
|
||||
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
|
||||
return
|
||||
|
@ -315,16 +299,14 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
|
|||
}
|
||||
|
||||
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
|
||||
func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
||||
if prov.Policy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
|
@ -333,7 +315,8 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter,
|
|||
// remove the policy
|
||||
prov.Policy = nil
|
||||
|
||||
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
@ -341,16 +324,14 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter,
|
|||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
|
@ -360,17 +341,15 @@ func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *
|
|||
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
|
||||
|
@ -394,7 +373,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
|
|||
eak.Policy = newPolicy
|
||||
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
@ -402,17 +382,15 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
|
|||
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
|
@ -434,7 +412,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
|
|||
|
||||
eak.Policy = newPolicy
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
@ -442,17 +421,15 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
|
|||
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := par.blockLinkedCA(); err != nil {
|
||||
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
|
@ -463,7 +440,8 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
|
|||
eak.Policy = nil
|
||||
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
@ -472,9 +450,10 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
|
|||
}
|
||||
|
||||
// blockLinkedCA blocks all API operations on linked deployments
|
||||
func (par *PolicyAdminResponder) blockLinkedCA() error {
|
||||
func blockLinkedCA(ctx context.Context) error {
|
||||
// temporary blocking linked deployments
|
||||
if par.isLinkedCA {
|
||||
adminDB := admin.MustFromContext(ctx)
|
||||
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() {
|
||||
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -110,6 +110,7 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
|
|||
err.Message = "error retrieving authority policy: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
||||
|
@ -125,6 +126,7 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
|
|||
err.Message = "authority policy does not exist"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||
|
@ -180,6 +182,7 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
|
|||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return policy, nil
|
||||
|
@ -234,11 +237,12 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.GetAuthorityPolicy(w, req)
|
||||
|
@ -302,6 +306,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
|||
err.Message = "error retrieving authority policy: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
||||
|
@ -317,6 +322,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
|||
err.Message = "authority already has a policy"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return &linkedca.Policy{}, nil
|
||||
|
@ -333,6 +339,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
|||
body := []byte("{?}")
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||
|
@ -359,6 +366,7 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
|||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||
|
@ -509,11 +517,13 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.CreateAuthorityPolicy(w, req)
|
||||
|
@ -587,6 +597,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
|||
err.Message = "error retrieving authority policy: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
||||
|
@ -603,6 +614,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
|||
err.Status = http.StatusNotFound
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, nil
|
||||
|
@ -626,6 +638,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
|||
body := []byte("{?}")
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return policy, nil
|
||||
|
@ -659,6 +672,7 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
|||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return policy, nil
|
||||
|
@ -809,11 +823,13 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.UpdateAuthorityPolicy(w, req)
|
||||
|
@ -887,6 +903,7 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
|||
err.Message = "error retrieving authority policy: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
||||
|
@ -903,6 +920,7 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
|||
err.Status = http.StatusNotFound
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, nil
|
||||
|
@ -925,6 +943,7 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
|||
err.Message = "error deleting authority policy: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return policy, nil
|
||||
|
@ -948,6 +967,7 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return policy, nil
|
||||
|
@ -963,11 +983,13 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.DeleteAuthorityPolicy(w, req)
|
||||
|
@ -1033,6 +1055,7 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
|
|||
err.Message = "provisioner policy does not exist"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
|
@ -1086,6 +1109,7 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
|
|||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
response: &testPolicyResponse{
|
||||
X509: &testX509Policy{
|
||||
Allow: &testX509Names{
|
||||
|
@ -1135,11 +1159,13 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.GetProvisionerPolicy(w, req)
|
||||
|
@ -1214,6 +1240,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
|||
err.Message = "provisioner provName already has a policy"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 409,
|
||||
}
|
||||
|
@ -1228,6 +1255,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
|||
body := []byte("{?}")
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
|
@ -1252,6 +1280,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
|||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||
|
@ -1284,6 +1313,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return &authority.PolicyError{
|
||||
|
@ -1319,6 +1349,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return &authority.PolicyError{
|
||||
|
@ -1352,6 +1383,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return nil
|
||||
|
@ -1372,11 +1404,12 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.CreateProvisionerPolicy(w, req)
|
||||
|
@ -1452,6 +1485,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
|||
err.Message = "provisioner policy does not exist"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
|
@ -1474,6 +1508,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
|||
body := []byte("{?}")
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
|
@ -1506,6 +1541,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
|||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||
|
@ -1539,6 +1575,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return &authority.PolicyError{
|
||||
|
@ -1575,6 +1612,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return &authority.PolicyError{
|
||||
|
@ -1609,6 +1647,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return nil
|
||||
|
@ -1629,11 +1668,12 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.UpdateProvisionerPolicy(w, req)
|
||||
|
@ -1710,6 +1750,7 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
|
|||
err.Message = "provisioner policy does not exist"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
|
@ -1724,6 +1765,7 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
|
|||
err.Message = "error deleting provisioner policy: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return errors.New("force")
|
||||
|
@ -1741,6 +1783,7 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
|
|||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return nil
|
||||
|
@ -1753,11 +1796,13 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.DeleteProvisionerPolicy(w, req)
|
||||
|
@ -1828,6 +1873,7 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
|
|||
err.Message = "ACME EAK policy does not exist"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
|
@ -1886,6 +1932,7 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
|
|||
ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
response: &testPolicyResponse{
|
||||
X509: &testX509Policy{
|
||||
Allow: &testX509Names{
|
||||
|
@ -1935,11 +1982,12 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.GetACMEAccountPolicy(w, req)
|
||||
|
@ -2018,6 +2066,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
|||
err.Message = "ACME EAK eakID already has a policy"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 409,
|
||||
}
|
||||
|
@ -2036,6 +2085,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
|||
body := []byte("{?}")
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
|
@ -2064,6 +2114,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
|||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
|
@ -2092,6 +2143,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
acmeDB: &acme.MockDB{
|
||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
assert.Equal(t, "provID", provisionerID)
|
||||
|
@ -2125,6 +2177,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
acmeDB: &acme.MockDB{
|
||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
assert.Equal(t, "provID", provisionerID)
|
||||
|
@ -2147,11 +2200,12 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.CreateACMEAccountPolicy(w, req)
|
||||
|
@ -2231,6 +2285,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
|||
err.Message = "ACME EAK policy does not exist"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
|
@ -2257,6 +2312,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
|||
body := []byte("{?}")
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
|
@ -2293,6 +2349,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
|||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
|
@ -2322,6 +2379,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
acmeDB: &acme.MockDB{
|
||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
assert.Equal(t, "provID", provisionerID)
|
||||
|
@ -2356,6 +2414,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
acmeDB: &acme.MockDB{
|
||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
assert.Equal(t, "provID", provisionerID)
|
||||
|
@ -2378,11 +2437,12 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.UpdateACMEAccountPolicy(w, req)
|
||||
|
@ -2462,6 +2522,7 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
|
|||
err.Message = "ACME EAK policy does not exist"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
|
@ -2488,6 +2549,7 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
|
|||
err.Message = "error deleting ACME EAK policy: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
acmeDB: &acme.MockDB{
|
||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
assert.Equal(t, "provID", provisionerID)
|
||||
|
@ -2519,6 +2581,7 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
|
|||
ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
acmeDB: &acme.MockDB{
|
||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
assert.Equal(t, "provID", provisionerID)
|
||||
|
@ -2533,11 +2596,12 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
|
||||
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||
par := NewPolicyAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
par.DeleteACMEAccountPolicy(w, req)
|
||||
|
|
|
@ -23,29 +23,31 @@ type GetProvisionersResponse struct {
|
|||
}
|
||||
|
||||
// GetProvisioner returns the requested provisioner, or an error.
|
||||
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
func GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
|
||||
ctx := r.Context()
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
auth := mustAuthority(ctx)
|
||||
db := admin.MustFromContext(ctx)
|
||||
|
||||
if len(id) > 0 {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
prov, err := db.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
||||
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
func GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
|
@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
p, next, err := h.auth.GetProvisioners(cursor, limit)
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// CreateProvisioner creates a new prov.
|
||||
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var prov = new(linkedca.Provisioner)
|
||||
if err := read.ProtoJSON(r.Body, prov); err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
|
||||
if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||
return
|
||||
}
|
||||
|
@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// DeleteProvisioner deletes a provisioner.
|
||||
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
auth := mustAuthority(r.Context())
|
||||
|
||||
if len(id) > 0 {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||
return
|
||||
}
|
||||
|
@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// UpdateProvisioner updates an existing prov.
|
||||
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var nu = new(linkedca.Provisioner)
|
||||
if err := read.ProtoJSON(r.Body, nu); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
name := chi.URLParam(r, "name")
|
||||
_old, err := h.auth.LoadProvisionerByName(name)
|
||||
auth := mustAuthority(ctx)
|
||||
db := admin.MustFromContext(ctx)
|
||||
|
||||
p, err := auth.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||
return
|
||||
}
|
||||
|
||||
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
|
||||
old, err := db.GetProvisioner(r.Context(), p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
ctx: ctx,
|
||||
req: req,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
|
@ -74,6 +75,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
ctx: ctx,
|
||||
req: req,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
|
@ -156,13 +158,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := tc.req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetProvisioner(w, req)
|
||||
GetProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -280,12 +280,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetProvisioners(w, req)
|
||||
GetProvisioners(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -405,13 +403,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateProvisioner(w, req)
|
||||
CreateProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -571,12 +567,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteProvisioner(w, req)
|
||||
DeleteProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -625,6 +619,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: "badRequest",
|
||||
|
@ -654,6 +649,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
|
@ -1061,14 +1057,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateProvisioner(w, req)
|
||||
UpdateProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
|
@ -76,6 +76,29 @@ type DB interface {
|
|||
DeleteAuthorityPolicy(ctx context.Context) error
|
||||
}
|
||||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewContext adds the given admin database to the context.
|
||||
func NewContext(ctx context.Context, db DB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// FromContext returns the current admin database from the given context.
|
||||
func FromContext(ctx context.Context) (db DB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(DB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustFromContext returns the current admin database from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustFromContext(ctx context.Context) DB {
|
||||
if db, ok := FromContext(ctx); !ok {
|
||||
panic("admin database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
// a mock in tests.
|
||||
type MockDB struct {
|
||||
|
|
|
@ -167,6 +167,29 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
|
|||
return a, nil
|
||||
}
|
||||
|
||||
type authorityKey struct{}
|
||||
|
||||
// NewContext adds the given authority to the context.
|
||||
func NewContext(ctx context.Context, a *Authority) context.Context {
|
||||
return context.WithValue(ctx, authorityKey{}, a)
|
||||
}
|
||||
|
||||
// FromContext returns the current authority from the given context.
|
||||
func FromContext(ctx context.Context) (a *Authority, ok bool) {
|
||||
a, ok = ctx.Value(authorityKey{}).(*Authority)
|
||||
return
|
||||
}
|
||||
|
||||
// MustFromContext returns the current authority from the given context. It will
|
||||
// panic if the authority is not in the context.
|
||||
func MustFromContext(ctx context.Context) *Authority {
|
||||
if a, ok := FromContext(ctx); !ok {
|
||||
panic("authority is not in the context")
|
||||
} else {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadAdminResources reloads admins and provisioners from the DB.
|
||||
func (a *Authority) ReloadAdminResources(ctx context.Context) error {
|
||||
var (
|
||||
|
@ -235,6 +258,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
|
||||
var err error
|
||||
ctx := NewContext(context.Background(), a)
|
||||
|
||||
// Set password if they are not set.
|
||||
var configPassword []byte
|
||||
|
@ -270,7 +294,7 @@ func (a *Authority) init() error {
|
|||
if a.config.KMS != nil {
|
||||
options = *a.config.KMS
|
||||
}
|
||||
a.keyManager, err = kms.New(context.Background(), options)
|
||||
a.keyManager, err = kms.New(ctx, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -300,7 +324,7 @@ func (a *Authority) init() error {
|
|||
|
||||
// Configure linked RA
|
||||
if linkedcaClient != nil && options.CertificateAuthority == "" {
|
||||
conf, err := linkedcaClient.GetConfiguration(context.Background())
|
||||
conf, err := linkedcaClient.GetConfiguration(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -334,7 +358,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
a.x509CAService, err = cas.New(context.Background(), options)
|
||||
a.x509CAService, err = cas.New(ctx, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -521,7 +545,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
a.scepService, err = scep.NewService(context.Background(), options)
|
||||
a.scepService, err = scep.NewService(ctx, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -543,19 +567,19 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
provs, err := a.adminDB.GetProvisioners(context.Background())
|
||||
provs, err := a.adminDB.GetProvisioners(ctx)
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
||||
}
|
||||
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
|
||||
// Create First Provisioner
|
||||
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password))
|
||||
prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating first provisioner")
|
||||
}
|
||||
|
||||
// Create first admin
|
||||
if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{
|
||||
if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
|
||||
ProvisionerId: prov.Id,
|
||||
Subject: "step",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
|
@ -571,7 +595,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
|
||||
// Load x509 and SSH Policy Engines
|
||||
if err := a.reloadPolicyEngines(context.Background()); err != nil {
|
||||
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -596,6 +620,15 @@ func (a *Authority) init() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetID returns the define authority id or a zero uuid.
|
||||
func (a *Authority) GetID() string {
|
||||
const zeroUUID = "00000000-0000-0000-0000-000000000000"
|
||||
if id := a.config.AuthorityConfig.AuthorityID; id != "" {
|
||||
return id
|
||||
}
|
||||
return zeroUUID
|
||||
}
|
||||
|
||||
// GetDatabase returns the authority database. If the configuration does not
|
||||
// define a database, GetDatabase will return a db.SimpleDB instance.
|
||||
func (a *Authority) GetDatabase() db.AuthDB {
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"go.step.sm/crypto/jose"
|
||||
|
@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_GetID(t *testing.T) {
|
||||
type fields struct {
|
||||
authorityID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{"ok", fields{""}, "00000000-0000-0000-0000-000000000000"},
|
||||
{"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Authority{
|
||||
config: &config.Config{
|
||||
AuthorityConfig: &config.AuthConfig{
|
||||
AuthorityID: tt.fields.authorityID,
|
||||
},
|
||||
},
|
||||
}
|
||||
if got := a.GetID(); got != tt.want {
|
||||
t.Errorf("Authority.GetID() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -260,8 +260,7 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio
|
|||
// AuthorizeSign authorizes a signature request by validating and authenticating
|
||||
// a token that must be sent w/ the request.
|
||||
//
|
||||
// NOTE: This method is deprecated and should not be used. We make it available
|
||||
// in the short term os as not to break existing clients.
|
||||
// Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error).
|
||||
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
return a.Authorize(ctx, token)
|
||||
|
|
|
@ -54,7 +54,11 @@ func startCABootstrapServer() *httptest.Server {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
baseContext := buildContext(ca.auth, nil, nil, nil)
|
||||
srv.Config.Handler = ca.srv.Handler
|
||||
srv.Config.BaseContext = func(net.Listener) context.Context {
|
||||
return baseContext
|
||||
}
|
||||
srv.TLS = ca.srv.TLSConfig
|
||||
srv.StartTLS()
|
||||
// Force the use of GetCertificate on IPs
|
||||
|
|
74
ca/ca.go
74
ca/ca.go
|
@ -1,10 +1,12 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
@ -18,6 +20,7 @@ import (
|
|||
acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
adminAPI "github.com/smallstep/certificates/authority/admin/api"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/db"
|
||||
|
@ -170,10 +173,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
insecureHandler := http.Handler(insecureMux)
|
||||
|
||||
// Add regular CA api endpoints in / and /1.0
|
||||
routerHandler := api.New(auth)
|
||||
routerHandler.Route(mux)
|
||||
api.Route(mux)
|
||||
mux.Route("/1.0", func(r chi.Router) {
|
||||
routerHandler.Route(r)
|
||||
api.Route(r)
|
||||
})
|
||||
|
||||
//Add ACME api endpoints in /acme and /1.0/acme
|
||||
|
@ -187,49 +189,41 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
dns = fmt.Sprintf("%s:%s", dns, port)
|
||||
}
|
||||
|
||||
// ACME Router
|
||||
prefix := "acme"
|
||||
// ACME Router is only available if we have a database.
|
||||
var acmeDB acme.DB
|
||||
if cfg.DB == nil {
|
||||
acmeDB = nil
|
||||
} else {
|
||||
var acmeLinker acme.Linker
|
||||
if cfg.DB != nil {
|
||||
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error configuring ACME DB interface")
|
||||
}
|
||||
}
|
||||
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
|
||||
Backdate: *cfg.AuthorityConfig.Backdate,
|
||||
DB: acmeDB,
|
||||
DNS: dns,
|
||||
Prefix: prefix,
|
||||
CA: auth,
|
||||
})
|
||||
mux.Route("/"+prefix, func(r chi.Router) {
|
||||
acmeHandler.Route(r)
|
||||
acmeLinker = acme.NewLinker(dns, "acme")
|
||||
mux.Route("/acme", func(r chi.Router) {
|
||||
acmeAPI.Route(r)
|
||||
})
|
||||
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
||||
// of the ACME spec.
|
||||
mux.Route("/2.0/"+prefix, func(r chi.Router) {
|
||||
acmeHandler.Route(r)
|
||||
mux.Route("/2.0/acme", func(r chi.Router) {
|
||||
acmeAPI.Route(r)
|
||||
})
|
||||
}
|
||||
|
||||
// Admin API Router
|
||||
if cfg.AuthorityConfig.EnableAdmin {
|
||||
adminDB := auth.GetAdminDatabase()
|
||||
if adminDB != nil {
|
||||
acmeAdminResponder := adminAPI.NewACMEAdminResponder()
|
||||
policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB, acmeDB)
|
||||
adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder, policyAdminResponder)
|
||||
policyAdminResponder := adminAPI.NewPolicyAdminResponder()
|
||||
mux.Route("/admin", func(r chi.Router) {
|
||||
adminHandler.Route(r)
|
||||
adminAPI.Route(r, acmeAdminResponder, policyAdminResponder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var scepAuthority *scep.Authority
|
||||
if ca.shouldServeSCEPEndpoints() {
|
||||
scepPrefix := "scep"
|
||||
scepAuthority, err := scep.New(auth, scep.AuthorityOptions{
|
||||
scepAuthority, err = scep.New(auth, scep.AuthorityOptions{
|
||||
Service: auth.GetSCEPService(),
|
||||
DNS: dns,
|
||||
Prefix: scepPrefix,
|
||||
|
@ -237,13 +231,12 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error creating SCEP authority")
|
||||
}
|
||||
scepRouterHandler := scepAPI.New(scepAuthority)
|
||||
|
||||
// According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10),
|
||||
// SCEP operations are performed using HTTP, so that's why the API is mounted
|
||||
// to the insecure mux.
|
||||
insecureMux.Route("/"+scepPrefix, func(r chi.Router) {
|
||||
scepRouterHandler.Route(r)
|
||||
scepAPI.Route(r)
|
||||
})
|
||||
|
||||
// The RFC also mentions usage of HTTPS, but seems to advise
|
||||
|
@ -253,7 +246,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
// as well as HTTPS can be used to request certificates
|
||||
// using SCEP.
|
||||
mux.Route("/"+scepPrefix, func(r chi.Router) {
|
||||
scepRouterHandler.Route(r)
|
||||
scepAPI.Route(r)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -280,7 +273,13 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
insecureHandler = logger.Middleware(insecureHandler)
|
||||
}
|
||||
|
||||
// Create context with all the necessary values.
|
||||
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
|
||||
|
||||
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
||||
ca.srv.BaseContext = func(net.Listener) context.Context {
|
||||
return baseContext
|
||||
}
|
||||
|
||||
// only start the insecure server if the insecure address is configured
|
||||
// and, currently, also only when it should serve SCEP endpoints.
|
||||
|
@ -290,11 +289,32 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
// will probably introduce more complexity in terms of graceful
|
||||
// reload.
|
||||
ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil)
|
||||
ca.insecureSrv.BaseContext = func(net.Listener) context.Context {
|
||||
return baseContext
|
||||
}
|
||||
}
|
||||
|
||||
return ca, nil
|
||||
}
|
||||
|
||||
// buildContext builds the server base context.
|
||||
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context {
|
||||
ctx := authority.NewContext(context.Background(), a)
|
||||
if authDB := a.GetDatabase(); authDB != nil {
|
||||
ctx = db.NewContext(ctx, authDB)
|
||||
}
|
||||
if adminDB := a.GetAdminDatabase(); adminDB != nil {
|
||||
ctx = admin.NewContext(ctx, adminDB)
|
||||
}
|
||||
if scepAuthority != nil {
|
||||
ctx = scep.NewContext(ctx, scepAuthority)
|
||||
}
|
||||
if acmeDB != nil {
|
||||
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Run starts the CA calling to the server ListenAndServe method.
|
||||
func (ca *CA) Run() error {
|
||||
var wg sync.WaitGroup
|
||||
|
|
|
@ -2,6 +2,7 @@ package ca
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
|
@ -281,7 +282,8 @@ ZEp7knvU2psWRw==
|
|||
assert.FatalError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
||||
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||
|
||||
if assert.Equals(t, rr.Code, tc.status) {
|
||||
body := &ClosingBuffer{rr.Body}
|
||||
|
@ -360,7 +362,8 @@ func TestCAProvisioners(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
||||
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||
|
||||
if assert.Equals(t, rr.Code, tc.status) {
|
||||
body := &ClosingBuffer{rr.Body}
|
||||
|
@ -426,7 +429,8 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
||||
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||
|
||||
if assert.Equals(t, rr.Code, tc.status) {
|
||||
body := &ClosingBuffer{rr.Body}
|
||||
|
@ -487,7 +491,8 @@ func TestCARoot(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
||||
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||
|
||||
if assert.Equals(t, rr.Code, tc.status) {
|
||||
body := &ClosingBuffer{rr.Body}
|
||||
|
@ -534,7 +539,8 @@ func TestCAHealth(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
||||
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||
|
||||
if assert.Equals(t, rr.Code, tc.status) {
|
||||
body := &ClosingBuffer{rr.Body}
|
||||
|
@ -628,7 +634,8 @@ func TestCARenew(t *testing.T) {
|
|||
rq.TLS = tc.tlsConnState
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
||||
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||
|
||||
if assert.Equals(t, rr.Code, tc.status) {
|
||||
body := &ClosingBuffer{rr.Body}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"encoding/hex"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
|
@ -77,7 +78,12 @@ func startCATestServer() *httptest.Server {
|
|||
panic(err)
|
||||
}
|
||||
// Use a httptest.Server instead
|
||||
return startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
|
||||
srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
|
||||
baseContext := buildContext(ca.auth, nil, nil, nil)
|
||||
srv.Config.BaseContext = func(net.Listener) context.Context {
|
||||
return baseContext
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
|
||||
|
|
67
cas/vaultcas/auth/approle/approle.go
Normal file
67
cas/vaultcas/auth/approle/approle.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package approle
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/api/auth/approle"
|
||||
)
|
||||
|
||||
// AuthOptions defines the configuration options added using the
|
||||
// VaultOptions.AuthOptions field when AuthType is approle
|
||||
type AuthOptions struct {
|
||||
RoleID string `json:"roleID,omitempty"`
|
||||
SecretID string `json:"secretID,omitempty"`
|
||||
SecretIDFile string `json:"secretIDFile,omitempty"`
|
||||
SecretIDEnv string `json:"secretIDEnv,omitempty"`
|
||||
IsWrappingToken bool `json:"isWrappingToken,omitempty"`
|
||||
}
|
||||
|
||||
func NewApproleAuthMethod(mountPath string, options json.RawMessage) (*approle.AppRoleAuth, error) {
|
||||
var opts *AuthOptions
|
||||
|
||||
err := json.Unmarshal(options, &opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding AppRole auth options: %w", err)
|
||||
}
|
||||
|
||||
var approleAuth *approle.AppRoleAuth
|
||||
|
||||
var loginOptions []approle.LoginOption
|
||||
if mountPath != "" {
|
||||
loginOptions = append(loginOptions, approle.WithMountPath(mountPath))
|
||||
}
|
||||
if opts.IsWrappingToken {
|
||||
loginOptions = append(loginOptions, approle.WithWrappingToken())
|
||||
}
|
||||
|
||||
if opts.RoleID == "" {
|
||||
return nil, errors.New("you must set roleID")
|
||||
}
|
||||
|
||||
var sid approle.SecretID
|
||||
switch {
|
||||
case opts.SecretID != "" && opts.SecretIDFile == "" && opts.SecretIDEnv == "":
|
||||
sid = approle.SecretID{
|
||||
FromString: opts.SecretID,
|
||||
}
|
||||
case opts.SecretIDFile != "" && opts.SecretID == "" && opts.SecretIDEnv == "":
|
||||
sid = approle.SecretID{
|
||||
FromFile: opts.SecretIDFile,
|
||||
}
|
||||
case opts.SecretIDEnv != "" && opts.SecretIDFile == "" && opts.SecretID == "":
|
||||
sid = approle.SecretID{
|
||||
FromEnv: opts.SecretIDEnv,
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("you must set one of secretID, secretIDFile or secretIDEnv")
|
||||
}
|
||||
|
||||
approleAuth, err = approle.NewAppRoleAuth(opts.RoleID, &sid, loginOptions...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err)
|
||||
}
|
||||
|
||||
return approleAuth, nil
|
||||
}
|
195
cas/vaultcas/auth/approle/approle_test.go
Normal file
195
cas/vaultcas/auth/approle/approle_test.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package approle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
vault "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
|
||||
t.Helper()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.RequestURI == "/v1/auth/approle/login":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{
|
||||
"auth": {
|
||||
"client_token": "hvs.0000"
|
||||
}
|
||||
}`)
|
||||
case r.RequestURI == "/v1/auth/custom-approle/login":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{
|
||||
"auth": {
|
||||
"client_token": "hvs.9999"
|
||||
}
|
||||
}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(func() {
|
||||
srv.Close()
|
||||
})
|
||||
u, err := url.Parse(srv.URL)
|
||||
if err != nil {
|
||||
srv.Close()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := vault.DefaultConfig()
|
||||
config.Address = srv.URL
|
||||
|
||||
client, err := vault.NewClient(config)
|
||||
if err != nil {
|
||||
srv.Close()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return u, client
|
||||
}
|
||||
|
||||
func TestApprole_LoginMountPaths(t *testing.T) {
|
||||
caURL, _ := testCAHelper(t)
|
||||
|
||||
config := vault.DefaultConfig()
|
||||
config.Address = caURL.String()
|
||||
client, _ := vault.NewClient(config)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mountPath string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "ok default mount path",
|
||||
mountPath: "",
|
||||
token: "hvs.0000",
|
||||
},
|
||||
{
|
||||
name: "ok explicit mount path",
|
||||
mountPath: "approle",
|
||||
token: "hvs.0000",
|
||||
},
|
||||
{
|
||||
name: "ok custom mount path",
|
||||
mountPath: "custom-approle",
|
||||
token: "hvs.9999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
method, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`))
|
||||
if err != nil {
|
||||
t.Errorf("NewApproleAuthMethod() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
secret, err := client.Auth().Login(context.Background(), method)
|
||||
if err != nil {
|
||||
t.Errorf("Login() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
token, _ := secret.TokenID()
|
||||
if token != tt.token {
|
||||
t.Errorf("Token error got %v, expected %v", token, tt.token)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprole_NewApproleAuthMethod(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mountPath string
|
||||
raw string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"ok secret-id string",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000"}`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok secret-id string and wrapped",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok secret-id string and wrapped with custom mountPath",
|
||||
"approle2",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok secret-id file",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok secret-id env",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"fail mandatory role-id",
|
||||
"",
|
||||
`{}`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail mandatory secret-id any",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000"}`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail multiple secret-id types id and env",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail multiple secret-id types id and file",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail multiple secret-id types env and file",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail multiple secret-id types all",
|
||||
"",
|
||||
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(tt.raw))
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Approle.NewApproleAuthMethod() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
49
cas/vaultcas/auth/kubernetes/kubernetes.go
Normal file
49
cas/vaultcas/auth/kubernetes/kubernetes.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package kubernetes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/api/auth/kubernetes"
|
||||
)
|
||||
|
||||
// AuthOptions defines the configuration options added using the
|
||||
// VaultOptions.AuthOptions field when AuthType is kubernetes
|
||||
type AuthOptions struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
TokenPath string `json:"tokenPath,omitempty"`
|
||||
}
|
||||
|
||||
func NewKubernetesAuthMethod(mountPath string, options json.RawMessage) (*kubernetes.KubernetesAuth, error) {
|
||||
var opts *AuthOptions
|
||||
|
||||
err := json.Unmarshal(options, &opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding Kubernetes auth options: %w", err)
|
||||
}
|
||||
|
||||
var kubernetesAuth *kubernetes.KubernetesAuth
|
||||
|
||||
var loginOptions []kubernetes.LoginOption
|
||||
if mountPath != "" {
|
||||
loginOptions = append(loginOptions, kubernetes.WithMountPath(mountPath))
|
||||
}
|
||||
if opts.TokenPath != "" {
|
||||
loginOptions = append(loginOptions, kubernetes.WithServiceAccountTokenPath(opts.TokenPath))
|
||||
}
|
||||
|
||||
if opts.Role == "" {
|
||||
return nil, errors.New("you must set role")
|
||||
}
|
||||
|
||||
kubernetesAuth, err = kubernetes.NewKubernetesAuth(
|
||||
opts.Role,
|
||||
loginOptions...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err)
|
||||
}
|
||||
|
||||
return kubernetesAuth, nil
|
||||
}
|
149
cas/vaultcas/auth/kubernetes/kubernetes_test.go
Normal file
149
cas/vaultcas/auth/kubernetes/kubernetes_test.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package kubernetes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
vault "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
|
||||
t.Helper()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.RequestURI == "/v1/auth/kubernetes/login":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{
|
||||
"auth": {
|
||||
"client_token": "hvs.0000"
|
||||
}
|
||||
}`)
|
||||
case r.RequestURI == "/v1/auth/custom-kubernetes/login":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{
|
||||
"auth": {
|
||||
"client_token": "hvs.9999"
|
||||
}
|
||||
}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"not found"}`)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(func() {
|
||||
srv.Close()
|
||||
})
|
||||
u, err := url.Parse(srv.URL)
|
||||
if err != nil {
|
||||
srv.Close()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := vault.DefaultConfig()
|
||||
config.Address = srv.URL
|
||||
|
||||
client, err := vault.NewClient(config)
|
||||
if err != nil {
|
||||
srv.Close()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return u, client
|
||||
}
|
||||
|
||||
func TestApprole_LoginMountPaths(t *testing.T) {
|
||||
caURL, _ := testCAHelper(t)
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
tokenPath := filepath.Join(path.Dir(filename), "token")
|
||||
|
||||
config := vault.DefaultConfig()
|
||||
config.Address = caURL.String()
|
||||
client, _ := vault.NewClient(config)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mountPath string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "ok default mount path",
|
||||
mountPath: "",
|
||||
token: "hvs.0000",
|
||||
},
|
||||
{
|
||||
name: "ok explicit mount path",
|
||||
mountPath: "kubernetes",
|
||||
token: "hvs.0000",
|
||||
},
|
||||
{
|
||||
name: "ok custom mount path",
|
||||
mountPath: "custom-kubernetes",
|
||||
token: "hvs.9999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
method, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(`{"role": "SomeRoleName", "tokenPath": "`+tokenPath+`"}`))
|
||||
if err != nil {
|
||||
t.Errorf("NewApproleAuthMethod() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
secret, err := client.Auth().Login(context.Background(), method)
|
||||
if err != nil {
|
||||
t.Errorf("Login() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
token, _ := secret.TokenID()
|
||||
if token != tt.token {
|
||||
t.Errorf("Token error got %v, expected %v", token, tt.token)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprole_NewApproleAuthMethod(t *testing.T) {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
tokenPath := filepath.Join(path.Dir(filename), "token")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mountPath string
|
||||
raw string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"ok secret-id string",
|
||||
"",
|
||||
`{"role": "SomeRoleName", "tokenPath": "` + tokenPath + `"}`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"fail mandatory role",
|
||||
"",
|
||||
`{}`,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(tt.raw))
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Kubernetes.NewKubernetesAuthMethod() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1
cas/vaultcas/auth/kubernetes/token
Normal file
1
cas/vaultcas/auth/kubernetes/token
Normal file
|
@ -0,0 +1 @@
|
|||
token
|
|
@ -15,9 +15,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/cas/vaultcas/auth/approle"
|
||||
"github.com/smallstep/certificates/cas/vaultcas/auth/kubernetes"
|
||||
|
||||
vault "github.com/hashicorp/vault/api"
|
||||
auth "github.com/hashicorp/vault/api/auth/approle"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -29,15 +30,14 @@ func init() {
|
|||
// VaultOptions defines the configuration options added using the
|
||||
// apiv1.Options.Config field.
|
||||
type VaultOptions struct {
|
||||
PKI string `json:"pki,omitempty"`
|
||||
PKIMountPath string `json:"pkiMountPath,omitempty"`
|
||||
PKIRoleDefault string `json:"pkiRoleDefault,omitempty"`
|
||||
PKIRoleRSA string `json:"pkiRoleRSA,omitempty"`
|
||||
PKIRoleEC string `json:"pkiRoleEC,omitempty"`
|
||||
PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"`
|
||||
RoleID string `json:"roleID,omitempty"`
|
||||
SecretID auth.SecretID `json:"secretID,omitempty"`
|
||||
AppRole string `json:"appRole,omitempty"`
|
||||
IsWrappingToken bool `json:"isWrappingToken,omitempty"`
|
||||
AuthType string `json:"authType,omitempty"`
|
||||
AuthMountPath string `json:"authMountPath,omitempty"`
|
||||
AuthOptions json.RawMessage `json:"authOptions,omitempty"`
|
||||
}
|
||||
|
||||
// VaultCAS implements a Certificate Authority Service using Hashicorp Vault.
|
||||
|
@ -77,28 +77,22 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
|
|||
return nil, fmt.Errorf("unable to initialize vault client: %w", err)
|
||||
}
|
||||
|
||||
var appRoleAuth *auth.AppRoleAuth
|
||||
if vc.IsWrappingToken {
|
||||
appRoleAuth, err = auth.NewAppRoleAuth(
|
||||
vc.RoleID,
|
||||
&vc.SecretID,
|
||||
auth.WithWrappingToken(),
|
||||
auth.WithMountPath(vc.AppRole),
|
||||
)
|
||||
} else {
|
||||
appRoleAuth, err = auth.NewAppRoleAuth(
|
||||
vc.RoleID,
|
||||
&vc.SecretID,
|
||||
auth.WithMountPath(vc.AppRole),
|
||||
)
|
||||
var method vault.AuthMethod
|
||||
switch vc.AuthType {
|
||||
case "kubernetes":
|
||||
method, err = kubernetes.NewKubernetesAuthMethod(vc.AuthMountPath, vc.AuthOptions)
|
||||
case "approle":
|
||||
method, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown auth type: %s, only 'kubernetes' and 'approle' currently supported", vc.AuthType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize AppRole auth method: %w", err)
|
||||
return nil, fmt.Errorf("unable to configure %s auth method: %w", vc.AuthType, err)
|
||||
}
|
||||
|
||||
authInfo, err := client.Auth().Login(ctx, appRoleAuth)
|
||||
authInfo, err := client.Auth().Login(ctx, method)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to login to AppRole auth method: %w", err)
|
||||
return nil, fmt.Errorf("unable to login to %s auth method: %w", vc.AuthType, err)
|
||||
}
|
||||
if authInfo == nil {
|
||||
return nil, errors.New("no auth info was returned after login")
|
||||
|
@ -134,7 +128,7 @@ func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv
|
|||
// GetCertificateAuthority returns the root certificate of the certificate
|
||||
// authority using the configured fingerprint.
|
||||
func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) {
|
||||
secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain")
|
||||
secret, err := v.client.Logical().Read(v.config.PKIMountPath + "/cert/ca_chain")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading ca chain: %w", err)
|
||||
}
|
||||
|
@ -190,7 +184,7 @@ func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv
|
|||
vaultReq := map[string]interface{}{
|
||||
"serial_number": formatSerialNumber(sn),
|
||||
}
|
||||
_, err := v.client.Logical().Write(v.config.PKI+"/revoke/", vaultReq)
|
||||
_, err := v.client.Logical().Write(v.config.PKIMountPath+"/revoke/", vaultReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error revoking certificate: %w", err)
|
||||
}
|
||||
|
@ -224,7 +218,7 @@ func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.
|
|||
"ttl": lifetime.Seconds(),
|
||||
}
|
||||
|
||||
secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, vaultReq)
|
||||
secret, err := v.client.Logical().Write(v.config.PKIMountPath+"/sign/"+vaultPKIRole, vaultReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error signing certificate: %w", err)
|
||||
}
|
||||
|
@ -247,21 +241,17 @@ func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.
|
|||
}
|
||||
|
||||
func loadOptions(config json.RawMessage) (*VaultOptions, error) {
|
||||
var vc *VaultOptions
|
||||
// setup default values
|
||||
vc := VaultOptions{
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "default",
|
||||
}
|
||||
|
||||
err := json.Unmarshal(config, &vc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding vaultCAS config: %w", err)
|
||||
}
|
||||
|
||||
if vc.PKI == "" {
|
||||
vc.PKI = "pki" // use default pki vault name
|
||||
}
|
||||
|
||||
if vc.PKIRoleDefault == "" {
|
||||
vc.PKIRoleDefault = "default" // use default pki role name
|
||||
}
|
||||
|
||||
if vc.PKIRoleRSA == "" {
|
||||
vc.PKIRoleRSA = vc.PKIRoleDefault
|
||||
}
|
||||
|
@ -272,23 +262,7 @@ func loadOptions(config json.RawMessage) (*VaultOptions, error) {
|
|||
vc.PKIRoleEd25519 = vc.PKIRoleDefault
|
||||
}
|
||||
|
||||
if vc.RoleID == "" {
|
||||
return nil, errors.New("vaultCAS config options must define `roleID`")
|
||||
}
|
||||
|
||||
if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" {
|
||||
return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`")
|
||||
}
|
||||
|
||||
if vc.PKI == "" {
|
||||
vc.PKI = "pki" // use default pki vault name
|
||||
}
|
||||
|
||||
if vc.AppRole == "" {
|
||||
vc.AppRole = "auth/approle"
|
||||
}
|
||||
|
||||
return vc, nil
|
||||
return &vc, nil
|
||||
}
|
||||
|
||||
func parseCertificates(pemCert string) []*x509.Certificate {
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"time"
|
||||
|
||||
vault "github.com/hashicorp/vault/api"
|
||||
auth "github.com/hashicorp/vault/api/auth/approle"
|
||||
"github.com/smallstep/certificates/cas/apiv1"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
@ -99,7 +98,7 @@ func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.RequestURI == "/v1/auth/auth/approle/login":
|
||||
case r.RequestURI == "/v1/auth/approle/login":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{
|
||||
"auth": {
|
||||
|
@ -183,11 +182,8 @@ func TestNew_register(t *testing.T) {
|
|||
CertificateAuthority: caURL.String(),
|
||||
CertificateAuthorityFingerprint: testRootFingerprint,
|
||||
Config: json.RawMessage(`{
|
||||
"PKI": "pki",
|
||||
"PKIRoleDefault": "pki-role",
|
||||
"RoleID": "roleID",
|
||||
"SecretID": {"FromString": "secretID"},
|
||||
"IsWrappingToken": false
|
||||
"AuthType": "approle",
|
||||
"AuthOptions": {"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}
|
||||
}`),
|
||||
})
|
||||
|
||||
|
@ -201,15 +197,11 @@ func TestVaultCAS_CreateCertificate(t *testing.T) {
|
|||
_, client := testCAHelper(t)
|
||||
|
||||
options := VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "role",
|
||||
PKIRoleRSA: "rsa",
|
||||
PKIRoleEC: "ec",
|
||||
PKIRoleEd25519: "ed25519",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromString: "secretID"},
|
||||
AppRole: "approle",
|
||||
IsWrappingToken: false,
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
|
@ -291,7 +283,7 @@ func TestVaultCAS_GetCertificateAuthority(t *testing.T) {
|
|||
}
|
||||
|
||||
options := VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
}
|
||||
|
||||
rootCert := parseCertificates(testRootCertificate)[0]
|
||||
|
@ -335,15 +327,11 @@ func TestVaultCAS_RevokeCertificate(t *testing.T) {
|
|||
_, client := testCAHelper(t)
|
||||
|
||||
options := VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "role",
|
||||
PKIRoleRSA: "rsa",
|
||||
PKIRoleEC: "ec",
|
||||
PKIRoleEd25519: "ed25519",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromString: "secretID"},
|
||||
AppRole: "approle",
|
||||
IsWrappingToken: false,
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
|
@ -407,15 +395,11 @@ func TestVaultCAS_RenewCertificate(t *testing.T) {
|
|||
_, client := testCAHelper(t)
|
||||
|
||||
options := VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "role",
|
||||
PKIRoleRSA: "rsa",
|
||||
PKIRoleEC: "ec",
|
||||
PKIRoleEd25519: "ed25519",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromString: "secretID"},
|
||||
AppRole: "approle",
|
||||
IsWrappingToken: false,
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
|
@ -464,202 +448,66 @@ func TestVaultCAS_loadOptions(t *testing.T) {
|
|||
want *VaultOptions
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"ok mandatory with SecretID FromString",
|
||||
`{"RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIRoleDefault: "default",
|
||||
PKIRoleRSA: "default",
|
||||
PKIRoleEC: "default",
|
||||
PKIRoleEd25519: "default",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromString: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory with SecretID FromFile",
|
||||
`{"RoleID": "roleID", "SecretID": {"FromFile": "secretID"}}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIRoleDefault: "default",
|
||||
PKIRoleRSA: "default",
|
||||
PKIRoleEC: "default",
|
||||
PKIRoleEd25519: "default",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromFile: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory with SecretID FromEnv",
|
||||
`{"RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIRoleDefault: "default",
|
||||
PKIRoleRSA: "default",
|
||||
PKIRoleEC: "default",
|
||||
PKIRoleEd25519: "default",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromEnv: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory PKIRole PKIRoleEd25519",
|
||||
`{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`,
|
||||
`{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519"}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "role",
|
||||
PKIRoleRSA: "role",
|
||||
PKIRoleEC: "role",
|
||||
PKIRoleEd25519: "ed25519",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromEnv: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory PKIRole PKIRoleEC",
|
||||
`{"PKIRoleDefault": "role", "PKIRoleEC": "ec" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`,
|
||||
`{"PKIRoleDefault": "role", "PKIRoleEC": "ec"}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "role",
|
||||
PKIRoleRSA: "role",
|
||||
PKIRoleEC: "ec",
|
||||
PKIRoleEd25519: "role",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromEnv: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory PKIRole PKIRoleRSA",
|
||||
`{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`,
|
||||
`{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa"}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "role",
|
||||
PKIRoleRSA: "rsa",
|
||||
PKIRoleEC: "role",
|
||||
PKIRoleEd25519: "role",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromEnv: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519",
|
||||
`{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519", "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`,
|
||||
`{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "default",
|
||||
PKIRoleRSA: "rsa",
|
||||
PKIRoleEC: "ec",
|
||||
PKIRoleEd25519: "ed25519",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromEnv: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519 with useless PKIRoleDefault",
|
||||
`{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519", "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`,
|
||||
`{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIMountPath: "pki",
|
||||
PKIRoleDefault: "role",
|
||||
PKIRoleRSA: "rsa",
|
||||
PKIRoleEC: "ec",
|
||||
PKIRoleEd25519: "ed25519",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromEnv: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory with AppRole",
|
||||
`{"AppRole": "test", "RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIRoleDefault: "default",
|
||||
PKIRoleRSA: "default",
|
||||
PKIRoleEC: "default",
|
||||
PKIRoleEd25519: "default",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromString: "secretID"},
|
||||
AppRole: "test",
|
||||
IsWrappingToken: false,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"ok mandatory with IsWrappingToken",
|
||||
`{"IsWrappingToken": true, "RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`,
|
||||
&VaultOptions{
|
||||
PKI: "pki",
|
||||
PKIRoleDefault: "default",
|
||||
PKIRoleRSA: "default",
|
||||
PKIRoleEC: "default",
|
||||
PKIRoleEd25519: "default",
|
||||
RoleID: "roleID",
|
||||
SecretID: auth.SecretID{FromString: "secretID"},
|
||||
AppRole: "auth/approle",
|
||||
IsWrappingToken: true,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"fail with SecretID FromFail",
|
||||
`{"RoleID": "roleID", "SecretID": {"FromFail": "secretID"}}`,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail with SecretID empty FromEnv",
|
||||
`{"RoleID": "roleID", "SecretID": {"FromEnv": ""}}`,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail with SecretID empty FromFile",
|
||||
`{"RoleID": "roleID", "SecretID": {"FromFile": ""}}`,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail with SecretID empty FromString",
|
||||
`{"RoleID": "roleID", "SecretID": {"FromString": ""}}`,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail mandatory with SecretID FromFail",
|
||||
`{"RoleID": "roleID", "SecretID": {"FromFail": "secretID"}}`,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"fail missing RoleID",
|
||||
`{"SecretID": {"FromString": "secretID"}}`,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
24
db/db.go
24
db/db.go
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
|
@ -56,6 +57,29 @@ type AuthDB interface {
|
|||
Shutdown() error
|
||||
}
|
||||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewContext adds the given authority database to the context.
|
||||
func NewContext(ctx context.Context, db AuthDB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// FromContext returns the current authority database from the given context.
|
||||
func FromContext(ctx context.Context) (db AuthDB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(AuthDB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustFromContext returns the current database from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustFromContext(ctx context.Context) AuthDB {
|
||||
if db, ok := FromContext(ctx); !ok {
|
||||
panic("authority database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
// CertificateStorer is an extension of AuthDB that allows to store
|
||||
// certificates.
|
||||
type CertificateStorer interface {
|
||||
|
|
1
go.mod
1
go.mod
|
@ -29,6 +29,7 @@ require (
|
|||
github.com/googleapis/gax-go/v2 v2.1.1
|
||||
github.com/hashicorp/vault/api v1.3.1
|
||||
github.com/hashicorp/vault/api/auth/approle v0.1.1
|
||||
github.com/hashicorp/vault/api/auth/kubernetes v0.1.0
|
||||
github.com/jhump/protoreflect v1.9.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.8 // indirect
|
||||
github.com/mattn/go-isatty v0.0.13 // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -449,6 +449,8 @@ github.com/hashicorp/vault/api v1.3.1 h1:pkDkcgTh47PRjY1NEFeofqR4W/HkNUi9qIakESO
|
|||
github.com/hashicorp/vault/api v1.3.1/go.mod h1:QeJoWxMFt+MsuWcYhmwRLwKEXrjwAFFywzhptMsTIUw=
|
||||
github.com/hashicorp/vault/api/auth/approle v0.1.1 h1:R5yA+xcNvw1ix6bDuWOaLOq2L4L77zDCVsethNw97xQ=
|
||||
github.com/hashicorp/vault/api/auth/approle v0.1.1/go.mod h1:mHOLgh//xDx4dpqXoq6tS8Ob0FoCFWLU2ibJ26Lfmag=
|
||||
github.com/hashicorp/vault/api/auth/kubernetes v0.1.0 h1:6BtyahbF4aQp8gg3ww0A/oIoqzbhpNP1spXU3nHE0n0=
|
||||
github.com/hashicorp/vault/api/auth/kubernetes v0.1.0/go.mod h1:Pdgk78uIs0mgDOLvc3a+h/vYIT9rznw2sz+ucuH9024=
|
||||
github.com/hashicorp/vault/sdk v0.3.0 h1:kR3dpxNkhh/wr6ycaJYqp6AFT/i2xaftbfnwZduTKEY=
|
||||
github.com/hashicorp/vault/sdk v0.3.0/go.mod h1:aZ3fNuL5VNydQk8GcLJ2TV8YCRVvyaakYkhZRoVuhj0=
|
||||
github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M=
|
||||
|
|
136
scep/api/api.go
136
scep/api/api.go
|
@ -38,8 +38,8 @@ type request struct {
|
|||
Message []byte
|
||||
}
|
||||
|
||||
// response is a SCEP server response.
|
||||
type response struct {
|
||||
// Response is a SCEP server Response.
|
||||
type Response struct {
|
||||
Operation string
|
||||
CACertNum int
|
||||
Data []byte
|
||||
|
@ -52,25 +52,48 @@ type handler struct {
|
|||
auth *scep.Authority
|
||||
}
|
||||
|
||||
// New returns a new SCEP API router.
|
||||
func New(auth *scep.Authority) api.RouterHandler {
|
||||
return &handler{
|
||||
auth: auth,
|
||||
// Route traffic and implement the Router interface.
|
||||
//
|
||||
// Deprecated: use scep.Route(r api.Router)
|
||||
func (h *handler) Route(r api.Router) {
|
||||
route(r, func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := scep.NewContext(r.Context(), h.auth)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// New returns a new SCEP API router.
|
||||
//
|
||||
// Deprecated: use scep.Route(r api.Router)
|
||||
func New(auth *scep.Authority) api.RouterHandler {
|
||||
return &handler{auth: auth}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *handler) Route(r api.Router) {
|
||||
getLink := h.auth.GetLinkExplicit
|
||||
r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get))
|
||||
r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get))
|
||||
r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post))
|
||||
r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post))
|
||||
func Route(r api.Router) {
|
||||
route(r, nil)
|
||||
}
|
||||
|
||||
func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc) {
|
||||
getHandler := lookupProvisioner(Get)
|
||||
postHandler := lookupProvisioner(Post)
|
||||
|
||||
// For backward compatibility.
|
||||
if middleware != nil {
|
||||
getHandler = middleware(getHandler)
|
||||
postHandler = middleware(postHandler)
|
||||
}
|
||||
|
||||
r.MethodFunc(http.MethodGet, "/{provisionerName}/*", getHandler)
|
||||
r.MethodFunc(http.MethodGet, "/{provisionerName}", getHandler)
|
||||
r.MethodFunc(http.MethodPost, "/{provisionerName}/*", postHandler)
|
||||
r.MethodFunc(http.MethodPost, "/{provisionerName}", postHandler)
|
||||
}
|
||||
|
||||
// Get handles all SCEP GET requests
|
||||
func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func Get(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
fail(w, fmt.Errorf("invalid scep get request: %w", err))
|
||||
|
@ -78,15 +101,15 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := r.Context()
|
||||
var res response
|
||||
var res Response
|
||||
|
||||
switch req.Operation {
|
||||
case opnGetCACert:
|
||||
res, err = h.GetCACert(ctx)
|
||||
res, err = GetCACert(ctx)
|
||||
case opnGetCACaps:
|
||||
res, err = h.GetCACaps(ctx)
|
||||
res, err = GetCACaps(ctx)
|
||||
case opnPKIOperation:
|
||||
res, err = h.PKIOperation(ctx, req)
|
||||
res, err = PKIOperation(ctx, req)
|
||||
default:
|
||||
err = fmt.Errorf("unknown operation: %s", req.Operation)
|
||||
}
|
||||
|
@ -100,20 +123,17 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Post handles all SCEP POST requests
|
||||
func (h *handler) Post(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func Post(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
fail(w, fmt.Errorf("invalid scep post request: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
var res response
|
||||
|
||||
var res Response
|
||||
switch req.Operation {
|
||||
case opnPKIOperation:
|
||||
res, err = h.PKIOperation(ctx, req)
|
||||
res, err = PKIOperation(r.Context(), req)
|
||||
default:
|
||||
err = fmt.Errorf("unknown operation: %s", req.Operation)
|
||||
}
|
||||
|
@ -127,7 +147,6 @@ func (h *handler) Post(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func decodeRequest(r *http.Request) (request, error) {
|
||||
|
||||
defer r.Body.Close()
|
||||
|
||||
method := r.Method
|
||||
|
@ -179,9 +198,8 @@ func decodeRequest(r *http.Request) (request, error) {
|
|||
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responds 404 if the provisioner does not exist.
|
||||
func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
||||
func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
name := chi.URLParam(r, "provisionerName")
|
||||
provisionerName, err := url.PathUnescape(name)
|
||||
if err != nil {
|
||||
|
@ -189,7 +207,9 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
p, err := h.auth.LoadProvisionerByName(provisionerName)
|
||||
ctx := r.Context()
|
||||
auth := scep.MustFromContext(ctx)
|
||||
p, err := auth.LoadProvisionerByName(provisionerName)
|
||||
if err != nil {
|
||||
fail(w, err)
|
||||
return
|
||||
|
@ -201,25 +221,24 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// GetCACert returns the CA certificates in a SCEP response
|
||||
func (h *handler) GetCACert(ctx context.Context) (response, error) {
|
||||
|
||||
certs, err := h.auth.GetCACertificates(ctx)
|
||||
func GetCACert(ctx context.Context) (Response, error) {
|
||||
auth := scep.MustFromContext(ctx)
|
||||
certs, err := auth.GetCACertificates(ctx)
|
||||
if err != nil {
|
||||
return response{}, err
|
||||
return Response{}, err
|
||||
}
|
||||
|
||||
if len(certs) == 0 {
|
||||
return response{}, errors.New("missing CA cert")
|
||||
return Response{}, errors.New("missing CA cert")
|
||||
}
|
||||
|
||||
res := response{
|
||||
res := Response{
|
||||
Operation: opnGetCACert,
|
||||
CACertNum: len(certs),
|
||||
}
|
||||
|
@ -232,7 +251,7 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) {
|
|||
// not signed or encrypted data has to be returned.
|
||||
data, err := microscep.DegenerateCertificates(certs)
|
||||
if err != nil {
|
||||
return response{}, err
|
||||
return Response{}, err
|
||||
}
|
||||
res.Data = data
|
||||
}
|
||||
|
@ -241,11 +260,11 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) {
|
|||
}
|
||||
|
||||
// GetCACaps returns the CA capabilities in a SCEP response
|
||||
func (h *handler) GetCACaps(ctx context.Context) (response, error) {
|
||||
func GetCACaps(ctx context.Context) (Response, error) {
|
||||
auth := scep.MustFromContext(ctx)
|
||||
caps := auth.GetCACaps(ctx)
|
||||
|
||||
caps := h.auth.GetCACaps(ctx)
|
||||
|
||||
res := response{
|
||||
res := Response{
|
||||
Operation: opnGetCACaps,
|
||||
Data: formatCapabilities(caps),
|
||||
}
|
||||
|
@ -254,13 +273,12 @@ func (h *handler) GetCACaps(ctx context.Context) (response, error) {
|
|||
}
|
||||
|
||||
// PKIOperation performs PKI operations and returns a SCEP response
|
||||
func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) {
|
||||
|
||||
func PKIOperation(ctx context.Context, req request) (Response, error) {
|
||||
// parse the message using microscep implementation
|
||||
microMsg, err := microscep.ParsePKIMessage(req.Message)
|
||||
if err != nil {
|
||||
// return the error, because we can't use the msg for creating a CertRep
|
||||
return response{}, err
|
||||
return Response{}, err
|
||||
}
|
||||
|
||||
// this is essentially doing the same as microscep.ParsePKIMessage, but
|
||||
|
@ -268,7 +286,7 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
|
|||
// wrapper for the microscep implementation.
|
||||
p7, err := pkcs7.Parse(microMsg.Raw)
|
||||
if err != nil {
|
||||
return response{}, err
|
||||
return Response{}, err
|
||||
}
|
||||
|
||||
// copy over properties to our internal PKIMessage
|
||||
|
@ -280,8 +298,9 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
|
|||
P7: p7,
|
||||
}
|
||||
|
||||
if err := h.auth.DecryptPKIEnvelope(ctx, msg); err != nil {
|
||||
return response{}, err
|
||||
auth := scep.MustFromContext(ctx)
|
||||
if err := auth.DecryptPKIEnvelope(ctx, msg); err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
|
||||
// NOTE: at this point we have sufficient information for returning nicely signed CertReps
|
||||
|
@ -293,13 +312,13 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
|
|||
// a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients.
|
||||
// We'll have to see how it works out.
|
||||
if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq {
|
||||
challengeMatches, err := h.auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword)
|
||||
challengeMatches, err := auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword)
|
||||
if err != nil {
|
||||
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password"))
|
||||
return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password"))
|
||||
}
|
||||
if !challengeMatches {
|
||||
// TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too.
|
||||
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided"))
|
||||
return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -311,12 +330,12 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
|
|||
// Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification
|
||||
// of the client cert is not.
|
||||
|
||||
certRep, err := h.auth.SignCSR(ctx, csr, msg)
|
||||
certRep, err := auth.SignCSR(ctx, csr, msg)
|
||||
if err != nil {
|
||||
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err))
|
||||
return createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err))
|
||||
}
|
||||
|
||||
res := response{
|
||||
res := Response{
|
||||
Operation: opnPKIOperation,
|
||||
Data: certRep.Raw,
|
||||
Certificate: certRep.Certificate,
|
||||
|
@ -330,7 +349,7 @@ func formatCapabilities(caps []string) []byte {
|
|||
}
|
||||
|
||||
// writeResponse writes a SCEP response back to the SCEP client.
|
||||
func writeResponse(w http.ResponseWriter, res response) {
|
||||
func writeResponse(w http.ResponseWriter, res Response) {
|
||||
|
||||
if res.Error != nil {
|
||||
log.Error(w, res.Error)
|
||||
|
@ -350,19 +369,20 @@ func fail(w http.ResponseWriter, err error) {
|
|||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) {
|
||||
certRepMsg, err := h.auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error())
|
||||
func createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (Response, error) {
|
||||
auth := scep.MustFromContext(ctx)
|
||||
certRepMsg, err := auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error())
|
||||
if err != nil {
|
||||
return response{}, err
|
||||
return Response{}, err
|
||||
}
|
||||
return response{
|
||||
return Response{
|
||||
Operation: opnPKIOperation,
|
||||
Data: certRepMsg.Raw,
|
||||
Error: failError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func contentHeader(r response) string {
|
||||
func contentHeader(r Response) string {
|
||||
switch r.Operation {
|
||||
default:
|
||||
return "text/plain"
|
||||
|
|
|
@ -27,6 +27,29 @@ type Authority struct {
|
|||
signAuth SignAuthority
|
||||
}
|
||||
|
||||
type authorityKey struct{}
|
||||
|
||||
// NewContext adds the given authority to the context.
|
||||
func NewContext(ctx context.Context, a *Authority) context.Context {
|
||||
return context.WithValue(ctx, authorityKey{}, a)
|
||||
}
|
||||
|
||||
// FromContext returns the current authority from the given context.
|
||||
func FromContext(ctx context.Context) (a *Authority, ok bool) {
|
||||
a, ok = ctx.Value(authorityKey{}).(*Authority)
|
||||
return
|
||||
}
|
||||
|
||||
// MustFromContext returns the current authority from the given context. It will
|
||||
// panic if the authority is not in the context.
|
||||
func MustFromContext(ctx context.Context) *Authority {
|
||||
if a, ok := FromContext(ctx); !ok {
|
||||
panic("scep authority is not in the context")
|
||||
} else {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorityOptions required to create a new SCEP Authority.
|
||||
type AuthorityOptions struct {
|
||||
// Service provides the certificate chain, the signer and the decrypter to the Authority
|
||||
|
@ -163,7 +186,6 @@ func (a *Authority) GetCACertificates(ctx context.Context) ([]*x509.Certificate,
|
|||
|
||||
// DecryptPKIEnvelope decrypts an enveloped message
|
||||
func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) error {
|
||||
|
||||
p7c, err := pkcs7.Parse(msg.P7.Content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing pkcs7 content: %w", err)
|
||||
|
@ -210,7 +232,6 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err
|
|||
// SignCSR creates an x509.Certificate based on a CSR template and Cert Authority credentials
|
||||
// returns a new PKIMessage with CertRep data
|
||||
func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage) (*PKIMessage, error) {
|
||||
|
||||
// TODO: intermediate storage of the request? In SCEP it's possible to request a csr/certificate
|
||||
// to be signed, which can be performed asynchronously / out-of-band. In that case a client can
|
||||
// poll for the status. It seems to be similar as what can happen in ACME, so might want to model
|
||||
|
@ -432,7 +453,6 @@ func (a *Authority) CreateFailureResponse(ctx context.Context, csr *x509.Certifi
|
|||
|
||||
// MatchChallengePassword verifies a SCEP challenge password
|
||||
func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) {
|
||||
|
||||
p, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
Loading…
Reference in a new issue