forked from TrueCloudLab/certificates
Merge pull request #914 from smallstep/context-authority
Retrieve authority from the context
This commit is contained in:
commit
539bfddba5
53 changed files with 2057 additions and 1576 deletions
|
@ -67,8 +67,11 @@ func (u *UpdateAccountRequest) Validate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccount is the handler resource for creating new ACME accounts.
|
// NewAccount is the handler resource for creating new ACME accounts.
|
||||||
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
payload, err := payloadFromContext(ctx)
|
payload, err := payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -114,7 +117,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
eak, err := h.validateExternalAccountBinding(ctx, &nar)
|
eak, err := validateExternalAccountBinding(ctx, &nar)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
|
@ -125,7 +128,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
Contact: nar.Contact,
|
Contact: nar.Contact,
|
||||||
Status: acme.StatusValid,
|
Status: acme.StatusValid,
|
||||||
}
|
}
|
||||||
if err := h.db.CreateAccount(ctx, acc); err != nil {
|
if err := db.CreateAccount(ctx, acc); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -135,7 +138,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -146,15 +149,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
httpStatus = http.StatusOK
|
httpStatus = http.StatusOK
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkAccount(ctx, acc)
|
linker.LinkAccount(ctx, acc)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
|
w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
|
||||||
render.JSONStatus(w, acc, httpStatus)
|
render.JSONStatus(w, acc, httpStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrUpdateAccount is the api for updating an ACME account.
|
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||||
func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -186,16 +192,16 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
acc.Contact = uar.Contact
|
acc.Contact = uar.Contact
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.UpdateAccount(ctx, acc); err != nil {
|
if err := db.UpdateAccount(ctx, acc); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkAccount(ctx, acc)
|
linker.LinkAccount(ctx, acc)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
|
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
|
||||||
render.JSON(w, acc)
|
render.JSON(w, acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,8 +215,11 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
|
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
|
||||||
func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -221,13 +230,14 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
|
|
||||||
|
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkOrdersByAccountID(ctx, orders)
|
linker.LinkOrdersByAccountID(ctx, orders)
|
||||||
|
|
||||||
render.JSON(w, orders)
|
render.JSON(w, orders)
|
||||||
logOrdersByAccount(w, orders)
|
logOrdersByAccount(w, orders)
|
||||||
|
|
|
@ -31,6 +31,22 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type fakeProvisioner struct{}
|
||||||
|
|
||||||
|
func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil }
|
||||||
|
func (*fakeProvisioner) GetID() string { return "" }
|
||||||
|
func (*fakeProvisioner) GetName() string { return "" }
|
||||||
|
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
|
||||||
|
func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil }
|
||||||
|
|
||||||
func newProv() acme.Provisioner {
|
func newProv() acme.Provisioner {
|
||||||
// Initialize provisioners
|
// Initialize provisioners
|
||||||
p := &provisioner.ACME{
|
p := &provisioner.ACME{
|
||||||
|
@ -320,10 +336,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: accID}
|
acc := &acme.Account{ID: accID}
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
|
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
|
||||||
|
@ -339,11 +354,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetOrdersByAccountID(w, req)
|
GetOrdersByAccountID(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -387,6 +402,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-payload": func(t *testing.T) test {
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload expected in request context"),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
|
@ -395,6 +411,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
"fail/nil-payload": func(t *testing.T) test {
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, nil)
|
ctx := context.WithValue(context.Background(), payloadContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload expected in request context"),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
|
@ -403,6 +420,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "failed to "+
|
err: acme.NewError(acme.ErrorMalformedType, "failed to "+
|
||||||
|
@ -417,6 +435,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||||
|
@ -429,8 +448,9 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
|
@ -442,9 +462,10 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jwk expected in request context"),
|
err: acme.NewErrorISE("jwk expected in request context"),
|
||||||
|
@ -456,10 +477,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jwk expected in request context"),
|
err: acme.NewErrorISE("jwk expected in request context"),
|
||||||
|
@ -478,9 +500,9 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
|
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
|
||||||
|
@ -495,7 +517,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -525,18 +547,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
scepProvisioner := &provisioner.SCEP{
|
|
||||||
Type: "SCEP",
|
|
||||||
Name: "test@scep-<test>provisioner.com",
|
|
||||||
}
|
|
||||||
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
}
|
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
|
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
|
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
|
||||||
|
@ -575,8 +590,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
eak := &acme.ExternalAccountKey{
|
eak := &acme.ExternalAccountKey{
|
||||||
ID: "eakID",
|
ID: "eakID",
|
||||||
|
@ -623,8 +637,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||||
|
@ -659,11 +672,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
Status: acme.StatusValid,
|
Status: acme.StatusValid,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
acc: acc,
|
acc: acc,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
|
@ -688,8 +701,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
prov.RequireEAB = false
|
prov.RequireEAB = false
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||||
|
@ -743,8 +755,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -783,11 +794,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.NewAccount(w, req)
|
NewAccount(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -838,6 +849,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
|
@ -846,6 +858,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
|
@ -854,6 +867,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
"fail/no-payload": func(t *testing.T) test {
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload expected in request context"),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
|
@ -863,6 +877,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload expected in request context"),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
|
@ -872,6 +887,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
|
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
|
||||||
|
@ -886,6 +902,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||||
|
@ -918,10 +935,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(uar)
|
b, err := json.Marshal(uar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||||
|
@ -938,11 +954,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
uar := &UpdateAccountRequest{}
|
uar := &UpdateAccountRequest{}
|
||||||
b, err := json.Marshal(uar)
|
b, err := json.Marshal(uar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
|
@ -953,10 +969,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(uar)
|
b, err := json.Marshal(uar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||||
|
@ -970,11 +985,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/post-as-get": func(t *testing.T) test {
|
"ok/post-as-get": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
|
@ -983,11 +998,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetOrUpdateAccount(w, req)
|
GetOrUpdateAccount(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
|
@ -17,7 +17,7 @@ type ExternalAccountBinding struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
|
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
|
||||||
func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
|
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
|
||||||
|
@ -48,7 +48,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
|
||||||
return nil, acmeErr
|
return nil, acmeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(*acme.Error); ok {
|
if _, ok := err.(*acme.Error); ok {
|
||||||
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
|
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
|
||||||
|
@ -111,7 +112,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
|
||||||
// o The "nonce" field MUST NOT be present
|
// o The "nonce" field MUST NOT be present
|
||||||
// o The "url" field MUST be set to the same value as the outer JWS
|
// o The "url" field MUST be set to the same value as the outer JWS
|
||||||
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
|
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
|
||||||
|
|
||||||
if jws == nil {
|
if jws == nil {
|
||||||
return "", acme.NewErrorISE("no JWS provided")
|
return "", acme.NewErrorISE("no JWS provided")
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_keysAreEqual(t *testing.T) {
|
func Test_keysAreEqual(t *testing.T) {
|
||||||
|
@ -100,8 +99,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{},
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -145,8 +143,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
createdAt := time.Now()
|
createdAt := time.Now()
|
||||||
return test{
|
return test{
|
||||||
|
@ -191,17 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
scepProvisioner := &provisioner.SCEP{
|
|
||||||
Type: "SCEP",
|
|
||||||
Name: "test@scep-<test>provisioner.com",
|
|
||||||
}
|
|
||||||
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
}
|
|
||||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
|
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
|
||||||
|
@ -220,8 +210,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{},
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -266,8 +255,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{},
|
db: &acme.MockDB{},
|
||||||
|
@ -312,8 +300,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -360,8 +347,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -410,8 +396,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -460,8 +445,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -510,8 +494,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
createdAt := time.Now()
|
createdAt := time.Now()
|
||||||
return test{
|
return test{
|
||||||
|
@ -568,8 +551,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -616,8 +598,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
createdAt := time.Now()
|
createdAt := time.Now()
|
||||||
boundAt := time.Now().Add(1 * time.Second)
|
boundAt := time.Now().Add(1 * time.Second)
|
||||||
|
@ -676,8 +657,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -734,8 +714,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -789,8 +768,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -845,8 +823,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
prov := newACMEProv(t)
|
prov := newACMEProv(t)
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
ctx := context.WithValue(context.Background(), jwkContextKey, nil)
|
ctx := context.WithValue(context.Background(), jwkContextKey, nil)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -873,10 +850,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
|
||||||
db: tc.db,
|
got, err := validateExternalAccountBinding(ctx, tc.nar)
|
||||||
}
|
|
||||||
got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar)
|
|
||||||
wantErr := tc.err != nil
|
wantErr := tc.err != nil
|
||||||
gotErr := err != nil
|
gotErr := err != nil
|
||||||
if wantErr != gotErr {
|
if wantErr != gotErr {
|
||||||
|
|
|
@ -2,12 +2,10 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -16,6 +14,7 @@ import (
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/api/render"
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,111 +38,152 @@ type payloadInfo struct {
|
||||||
isEmptyJSON bool
|
isEmptyJSON bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler is the ACME API request handler.
|
|
||||||
type Handler struct {
|
|
||||||
db acme.DB
|
|
||||||
backdate provisioner.Duration
|
|
||||||
ca acme.CertificateAuthority
|
|
||||||
linker Linker
|
|
||||||
validateChallengeOptions *acme.ValidateChallengeOptions
|
|
||||||
prerequisitesChecker func(ctx context.Context) (bool, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandlerOptions required to create a new ACME API request handler.
|
// HandlerOptions required to create a new ACME API request handler.
|
||||||
type HandlerOptions struct {
|
type HandlerOptions struct {
|
||||||
Backdate provisioner.Duration
|
// DB storage backend that implements the acme.DB interface.
|
||||||
// DB storage backend that impements the acme.DB interface.
|
//
|
||||||
|
// Deprecated: use acme.NewContex(context.Context, acme.DB)
|
||||||
DB acme.DB
|
DB acme.DB
|
||||||
|
|
||||||
|
// CA is the certificate authority interface.
|
||||||
|
//
|
||||||
|
// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
|
||||||
|
CA acme.CertificateAuthority
|
||||||
|
|
||||||
|
// Backdate is the duration that the CA will subtract from the current time
|
||||||
|
// to set the NotBefore in the certificate.
|
||||||
|
Backdate provisioner.Duration
|
||||||
|
|
||||||
// DNS the host used to generate accurate ACME links. By default the authority
|
// DNS the host used to generate accurate ACME links. By default the authority
|
||||||
// will use the Host from the request, so this value will only be used if
|
// will use the Host from the request, so this value will only be used if
|
||||||
// request.Host is empty.
|
// request.Host is empty.
|
||||||
DNS string
|
DNS string
|
||||||
|
|
||||||
// Prefix is a URL path prefix under which the ACME api is served. This
|
// Prefix is a URL path prefix under which the ACME api is served. This
|
||||||
// prefix is required to generate accurate ACME links.
|
// prefix is required to generate accurate ACME links.
|
||||||
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
||||||
// "acme" is the prefix from which the ACME api is accessed.
|
// "acme" is the prefix from which the ACME api is accessed.
|
||||||
Prefix string
|
Prefix string
|
||||||
CA acme.CertificateAuthority
|
|
||||||
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
||||||
// met by the CA configuration.
|
// met by the CA configuration.
|
||||||
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||||
|
return authority.MustFromContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handler is the ACME API request handler.
|
||||||
|
type handler struct {
|
||||||
|
opts *HandlerOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route traffic and implement the Router interface. For backward compatibility
|
||||||
|
// this route adds will add a new middleware that will set the ACME components
|
||||||
|
// on the context.
|
||||||
|
//
|
||||||
|
// Note: this method is deprecated in step-ca, other applications can still use
|
||||||
|
// this to support ACME, but the recommendation is to use use
|
||||||
|
// api.Route(api.Router) and acme.NewContext() instead.
|
||||||
|
func (h *handler) Route(r api.Router) {
|
||||||
|
client := acme.NewClient()
|
||||||
|
linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix)
|
||||||
|
route(r, func(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
|
||||||
|
ctx = authority.NewContext(ctx, ca)
|
||||||
|
}
|
||||||
|
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// NewHandler returns a new ACME API handler.
|
// NewHandler returns a new ACME API handler.
|
||||||
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
//
|
||||||
transport := &http.Transport{
|
// Note: this method is deprecated in step-ca, other applications can still use
|
||||||
TLSClientConfig: &tls.Config{
|
// this to support ACME, but the recommendation is to use use
|
||||||
InsecureSkipVerify: true,
|
// api.Route(api.Router) and acme.NewContext() instead.
|
||||||
},
|
func NewHandler(opts HandlerOptions) api.RouterHandler {
|
||||||
}
|
return &handler{
|
||||||
client := http.Client{
|
opts: &opts,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
Transport: transport,
|
|
||||||
}
|
|
||||||
dialer := &net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
prerequisitesChecker := func(ctx context.Context) (bool, error) {
|
|
||||||
// by default all prerequisites are met
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
if ops.PrerequisitesChecker != nil {
|
|
||||||
prerequisitesChecker = ops.PrerequisitesChecker
|
|
||||||
}
|
|
||||||
return &Handler{
|
|
||||||
ca: ops.CA,
|
|
||||||
db: ops.DB,
|
|
||||||
backdate: ops.Backdate,
|
|
||||||
linker: NewLinker(ops.DNS, ops.Prefix),
|
|
||||||
validateChallengeOptions: &acme.ValidateChallengeOptions{
|
|
||||||
HTTPGet: client.Get,
|
|
||||||
LookupTxt: net.LookupTXT,
|
|
||||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
|
||||||
return tls.DialWithDialer(dialer, network, addr, config)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
prerequisitesChecker: prerequisitesChecker,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route traffic and implement the Router interface.
|
// Route traffic and implement the Router interface. This method requires that
|
||||||
func (h *Handler) Route(r api.Router) {
|
// all the acme components, authority, db, client, linker, and prerequisite
|
||||||
getPath := h.linker.GetUnescapedPathSuffix
|
// checker to be present in the context.
|
||||||
// Standard ACME API
|
func Route(r api.Router) {
|
||||||
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
route(r, nil)
|
||||||
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
}
|
||||||
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
|
||||||
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
|
||||||
|
|
||||||
|
func route(r api.Router, middleware func(next nextHTTP) nextHTTP) {
|
||||||
|
commonMiddleware := func(next nextHTTP) nextHTTP {
|
||||||
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Linker middleware gets the provisioner and current url from the
|
||||||
|
// request and sets them in the context.
|
||||||
|
linker := acme.MustLinkerFromContext(r.Context())
|
||||||
|
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
if middleware != nil {
|
||||||
|
handler = middleware(handler)
|
||||||
|
}
|
||||||
|
return handler
|
||||||
|
}
|
||||||
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||||
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
|
return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
|
||||||
}
|
}
|
||||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||||
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
|
||||||
}
|
}
|
||||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||||
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))
|
return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
|
||||||
}
|
}
|
||||||
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
|
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
|
||||||
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next)))
|
return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
|
||||||
}
|
}
|
||||||
|
|
||||||
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
|
getPath := acme.GetUnescapedPathSuffix
|
||||||
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
|
|
||||||
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented))
|
// Standard ACME API
|
||||||
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder))
|
r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||||
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||||
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID)))
|
r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||||
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||||
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
|
r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||||
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
|
commonMiddleware(GetDirectory))
|
||||||
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||||
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert))
|
commonMiddleware(GetDirectory))
|
||||||
|
|
||||||
|
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
|
||||||
|
extractPayloadByJWK(NewAccount))
|
||||||
|
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
|
||||||
|
extractPayloadByKid(GetOrUpdateAccount))
|
||||||
|
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
|
||||||
|
extractPayloadByKid(NotImplemented))
|
||||||
|
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
|
||||||
|
extractPayloadByKid(NewOrder))
|
||||||
|
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
|
||||||
|
extractPayloadByKid(isPostAsGet(GetOrder)))
|
||||||
|
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
|
||||||
|
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
|
||||||
|
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
|
||||||
|
extractPayloadByKid(FinalizeOrder))
|
||||||
|
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
|
||||||
|
extractPayloadByKid(isPostAsGet(GetAuthorization)))
|
||||||
|
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
|
||||||
|
extractPayloadByKid(GetChallenge))
|
||||||
|
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
|
||||||
|
extractPayloadByKid(isPostAsGet(GetCertificate)))
|
||||||
|
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
|
||||||
|
extractPayloadByKidOrJWK(RevokeCert))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNonce just sets the right header since a Nonce is added to each response
|
// GetNonce just sets the right header since a Nonce is added to each response
|
||||||
// by middleware by default.
|
// by middleware by default.
|
||||||
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
func GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "HEAD" {
|
if r.Method == "HEAD" {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
} else {
|
} else {
|
||||||
|
@ -179,7 +219,7 @@ func (d *Directory) ToLog() (interface{}, error) {
|
||||||
|
|
||||||
// GetDirectory is the ACME resource for returning a directory configuration
|
// GetDirectory is the ACME resource for returning a directory configuration
|
||||||
// for client configuration.
|
// for client configuration.
|
||||||
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
func GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -187,12 +227,13 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
render.JSON(w, &Directory{
|
render.JSON(w, &Directory{
|
||||||
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
|
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
|
||||||
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
|
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
|
||||||
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
|
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
|
||||||
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType),
|
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
|
||||||
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType),
|
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
|
||||||
Meta: Meta{
|
Meta: Meta{
|
||||||
ExternalAccountRequired: acmeProv.RequireEAB,
|
ExternalAccountRequired: acmeProv.RequireEAB,
|
||||||
},
|
},
|
||||||
|
@ -201,19 +242,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
||||||
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
||||||
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
|
func NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthorization ACME api for retrieving an Authz.
|
// GetAuthorization ACME api for retrieving an Authz.
|
||||||
func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||||
return
|
return
|
||||||
|
@ -223,20 +267,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||||
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = az.UpdateStatus(ctx, h.db); err != nil {
|
if err = az.UpdateStatus(ctx, db); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkAuthorization(ctx, az)
|
linker.LinkAuthorization(ctx, az)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
|
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
|
||||||
render.JSON(w, az)
|
render.JSON(w, az)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetChallenge ACME api for retrieving a Challenge.
|
// GetChallenge ACME api for retrieving a Challenge.
|
||||||
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -257,7 +304,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
// we'll just ignore the body.
|
// we'll just ignore the body.
|
||||||
|
|
||||||
azID := chi.URLParam(r, "authzID")
|
azID := chi.URLParam(r, "authzID")
|
||||||
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||||
return
|
return
|
||||||
|
@ -273,29 +320,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
|
if err = ch.Validate(ctx, db, jwk); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkChallenge(ctx, ch, azID)
|
linker.LinkChallenge(ctx, ch, azID)
|
||||||
|
|
||||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
|
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
|
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
|
||||||
render.JSON(w, ch)
|
render.JSON(w, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificate ACME api for retrieving a Certificate.
|
// GetCertificate ACME api for retrieving a Certificate.
|
||||||
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
func GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
certID := chi.URLParam(r, "certID")
|
|
||||||
|
|
||||||
cert, err := h.db.GetCertificate(ctx, certID)
|
certID := chi.URLParam(r, "certID")
|
||||||
|
cert, err := db.GetCertificate(ctx, certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||||
return
|
return
|
||||||
|
|
|
@ -3,6 +3,7 @@ package api
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
@ -19,11 +20,33 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockClient struct {
|
||||||
|
get func(url string) (*http.Response, error)
|
||||||
|
lookupTxt func(name string) ([]string, error)
|
||||||
|
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) }
|
||||||
|
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
|
||||||
|
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
|
return m.tlsDial(network, addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) {
|
||||||
|
t.Helper()
|
||||||
|
fn := mustAuthority
|
||||||
|
t.Cleanup(func() {
|
||||||
|
mustAuthority = fn
|
||||||
|
})
|
||||||
|
mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandler_GetNonce(t *testing.T) {
|
func TestHandler_GetNonce(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -38,10 +61,10 @@ func TestHandler_GetNonce(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
// h := &Handler{}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
req.Method = tt.name
|
req.Method = tt.name
|
||||||
h.GetNonce(w, req)
|
GetNonce(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -52,7 +75,8 @@ func TestHandler_GetNonce(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_GetDirectory(t *testing.T) {
|
func TestHandler_GetDirectory(t *testing.T) {
|
||||||
linker := NewLinker("ca.smallstep.com", "acme")
|
linker := acme.NewLinker("ca.smallstep.com", "acme")
|
||||||
|
_ = linker
|
||||||
type test struct {
|
type test struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
statusCode int
|
statusCode int
|
||||||
|
@ -61,23 +85,14 @@ func TestHandler_GetDirectory(t *testing.T) {
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-provisioner": func(t *testing.T) test {
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: context.Background(),
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"),
|
err: acme.NewErrorISE("provisioner is not in context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/different-provisioner": func(t *testing.T) test {
|
"fail/different-provisioner": func(t *testing.T) test {
|
||||||
prov := &provisioner.SCEP{
|
ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{})
|
||||||
Type: "SCEP",
|
|
||||||
Name: "test@scep-<test>provisioner.com",
|
|
||||||
}
|
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
|
@ -88,8 +103,7 @@ func TestHandler_GetDirectory(t *testing.T) {
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
expDir := Directory{
|
expDir := Directory{
|
||||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||||
|
@ -108,8 +122,7 @@ func TestHandler_GetDirectory(t *testing.T) {
|
||||||
prov.RequireEAB = true
|
prov.RequireEAB = true
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
expDir := Directory{
|
expDir := Directory{
|
||||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||||
|
@ -130,11 +143,11 @@ func TestHandler_GetDirectory(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: linker}
|
ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetDirectory(w, req)
|
GetDirectory(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -219,7 +232,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{},
|
db: &acme.MockDB{},
|
||||||
|
@ -285,10 +298,9 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
|
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||||
|
@ -304,11 +316,11 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetAuthorization(w, req)
|
GetAuthorization(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -447,11 +459,11 @@ func TestHandler_GetCertificate(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetCertificate(w, req)
|
GetCertificate(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -491,7 +503,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
db acme.DB
|
db acme.DB
|
||||||
vco *acme.ValidateChallengeOptions
|
vc acme.Client
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
statusCode int
|
statusCode int
|
||||||
ch *acme.Challenge
|
ch *acme.Challenge
|
||||||
|
@ -500,6 +512,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
|
@ -507,6 +520,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), accContextKey, nil),
|
ctx: context.WithValue(context.Background(), accContextKey, nil),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
|
@ -516,6 +530,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload expected in request context"),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
|
@ -523,10 +538,11 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-payload": func(t *testing.T) test {
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload expected in request context"),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
|
@ -534,7 +550,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
@ -553,7 +569,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
@ -572,7 +588,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/no-jwk": func(t *testing.T) test {
|
"fail/no-jwk": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
@ -591,7 +607,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-jwk": func(t *testing.T) test {
|
"fail/nil-jwk": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||||
|
@ -611,7 +627,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/validate-challenge-error": func(t *testing.T) test {
|
"fail/validate-challenge-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
@ -639,8 +655,8 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
return acme.NewErrorISE("force")
|
return acme.NewErrorISE("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
vco: &acme.ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(string) (*http.Response, error) {
|
get: func(string) (*http.Response, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -651,14 +667,13 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
_pub := _jwk.Public()
|
_pub := _jwk.Public()
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -690,8 +705,8 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
URL: u,
|
URL: u,
|
||||||
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
||||||
},
|
},
|
||||||
vco: &acme.ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(string) (*http.Response, error) {
|
get: func(string) (*http.Response, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -703,11 +718,11 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetChallenge(w, req)
|
GetChallenge(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/keyutil"
|
"go.step.sm/crypto/keyutil"
|
||||||
|
|
||||||
|
@ -31,39 +30,11 @@ func logNonce(w http.ResponseWriter, nonce string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// baseURLFromRequest determines the base URL which should be used for
|
|
||||||
// constructing link URLs in e.g. the ACME directory result by taking the
|
|
||||||
// request Host into consideration.
|
|
||||||
//
|
|
||||||
// If the Request.Host is an empty string, we return an empty string, to
|
|
||||||
// indicate that the configured URL values should be used instead. If this
|
|
||||||
// function returns a non-empty result, then this should be used in
|
|
||||||
// constructing ACME link URLs.
|
|
||||||
func baseURLFromRequest(r *http.Request) *url.URL {
|
|
||||||
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
|
|
||||||
// for an implementation that allows HTTP requests using the x-forwarded-proto
|
|
||||||
// header.
|
|
||||||
|
|
||||||
if r.Host == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &url.URL{Scheme: "https", Host: r.Host}
|
|
||||||
}
|
|
||||||
|
|
||||||
// baseURLFromRequest is a middleware that extracts and caches the baseURL
|
|
||||||
// from the request.
|
|
||||||
// E.g. https://ca.smallstep.com/
|
|
||||||
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
|
|
||||||
next(w, r.WithContext(ctx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNonce is a middleware that adds a nonce to the response header.
|
// addNonce is a middleware that adds a nonce to the response header.
|
||||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
func addNonce(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
nonce, err := h.db.CreateNonce(r.Context())
|
db := acme.MustDatabaseFromContext(r.Context())
|
||||||
|
nonce, err := db.CreateNonce(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
|
@ -77,25 +48,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||||
|
|
||||||
// addDirLink is a middleware that adds a 'Link' response reader with the
|
// addDirLink is a middleware that adds a 'Link' response reader with the
|
||||||
// directory index url.
|
// directory index url.
|
||||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
func addDirLink(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index"))
|
ctx := r.Context()
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
|
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||||
next(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// verifyContentType is a middleware that verifies that content type is
|
// verifyContentType is a middleware that verifies that content type is
|
||||||
// application/jose+json.
|
// application/jose+json.
|
||||||
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
func verifyContentType(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
var expected []string
|
|
||||||
p, err := provisionerFromContext(r.Context())
|
p, err := provisionerFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
|
u := &url.URL{
|
||||||
|
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
var expected []string
|
||||||
if strings.Contains(r.URL.String(), u.EscapedPath()) {
|
if strings.Contains(r.URL.String(), u.EscapedPath()) {
|
||||||
// GET /certificate requests allow a greater range of content types.
|
// GET /certificate requests allow a greater range of content types.
|
||||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||||
|
@ -117,7 +94,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
||||||
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
func parseJWS(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -149,10 +126,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||||
// * “nonce” (defined in Section 6.5)
|
// * “nonce” (defined in Section 6.5)
|
||||||
// * “url” (defined in Section 6.4)
|
// * “url” (defined in Section 6.4)
|
||||||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
func validateJWS(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(r.Context())
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
|
||||||
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
|
@ -202,7 +181,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the validity/freshness of the Nonce.
|
// Check the validity/freshness of the Nonce.
|
||||||
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -235,10 +214,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
||||||
// in the context. Make sure to parse and validate the JWS before running this
|
// in the context. Make sure to parse and validate the JWS before running this
|
||||||
// middleware.
|
// middleware.
|
||||||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
func extractJWK(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(r.Context())
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
|
||||||
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
|
@ -264,7 +245,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
|
|
||||||
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
|
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
|
||||||
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
|
acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, acme.ErrNotFound):
|
case errors.Is(err, acme.ErrNotFound):
|
||||||
// For NewAccount and Revoke requests ...
|
// For NewAccount and Revoke requests ...
|
||||||
|
@ -283,63 +264,44 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupProvisioner loads the provisioner associated with the request.
|
|
||||||
// Responds 404 if the provisioner does not exist.
|
|
||||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
|
||||||
name, err := url.PathUnescape(nameEscaped)
|
|
||||||
if err != nil {
|
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p, err := h.ca.LoadProvisionerByName(name)
|
|
||||||
if err != nil {
|
|
||||||
render.Error(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
acmeProv, ok := p.(*provisioner.ACME)
|
|
||||||
if !ok {
|
|
||||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
|
|
||||||
next(w, r.WithContext(ctx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkPrerequisites checks if all prerequisites for serving ACME
|
// checkPrerequisites checks if all prerequisites for serving ACME
|
||||||
// are met by the CA configuration.
|
// are met by the CA configuration.
|
||||||
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
|
func checkPrerequisites(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ok, err := h.prerequisitesChecker(ctx)
|
// If the function is not set assume that all prerequisites are met.
|
||||||
if err != nil {
|
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
if ok {
|
||||||
return
|
ok, err := checkFunc(ctx)
|
||||||
|
if err != nil {
|
||||||
|
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !ok {
|
next(w, r)
|
||||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next(w, r.WithContext(ctx))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||||
// kid parameter of the signed payload.
|
// kid parameter of the signed payload.
|
||||||
// Make sure to parse and validate the JWS before running this middleware.
|
// Make sure to parse and validate the JWS before running this middleware.
|
||||||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
func lookupJWK(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
|
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
||||||
kid := jws.Signatures[0].Protected.KeyID
|
kid := jws.Signatures[0].Protected.KeyID
|
||||||
if !strings.HasPrefix(kid, kidPrefix) {
|
if !strings.HasPrefix(kid, kidPrefix) {
|
||||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
|
@ -349,7 +311,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
|
|
||||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||||
acc, err := h.db.GetAccount(ctx, accID)
|
acc, err := db.GetAccount(ctx, accID)
|
||||||
switch {
|
switch {
|
||||||
case nosql.IsErrNotFound(err):
|
case nosql.IsErrNotFound(err):
|
||||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||||
|
@ -372,7 +334,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||||
|
|
||||||
// extractOrLookupJWK forwards handling to either extractJWK or
|
// extractOrLookupJWK forwards handling to either extractJWK or
|
||||||
// lookupJWK based on the presence of a JWK or a KID, respectively.
|
// lookupJWK based on the presence of a JWK or a KID, respectively.
|
||||||
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
func extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
|
@ -385,13 +347,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||||
// and it can be used to check if a JWK exists. This flow is used when the ACME client
|
// and it can be used to check if a JWK exists. This flow is used when the ACME client
|
||||||
// signed the payload with a certificate private key.
|
// signed the payload with a certificate private key.
|
||||||
if canExtractJWKFrom(jws) {
|
if canExtractJWKFrom(jws) {
|
||||||
h.extractJWK(next)(w, r)
|
extractJWK(next)(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
|
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
|
||||||
// signed the payload with an account private key.
|
// signed the payload with an account private key.
|
||||||
h.lookupJWK(next)(w, r)
|
lookupJWK(next)(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -408,7 +370,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
||||||
|
|
||||||
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||||
// Make sure to parse and validate the JWS before running this middleware.
|
// Make sure to parse and validate the JWS before running this middleware.
|
||||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
|
@ -440,7 +402,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
||||||
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
func isPostAsGet(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
payload, err := payloadFromContext(r.Context())
|
payload, err := payloadFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -462,16 +424,12 @@ type ContextKey string
|
||||||
const (
|
const (
|
||||||
// accContextKey account key
|
// accContextKey account key
|
||||||
accContextKey = ContextKey("acc")
|
accContextKey = ContextKey("acc")
|
||||||
// baseURLContextKey baseURL key
|
|
||||||
baseURLContextKey = ContextKey("baseURL")
|
|
||||||
// jwsContextKey jws key
|
// jwsContextKey jws key
|
||||||
jwsContextKey = ContextKey("jws")
|
jwsContextKey = ContextKey("jws")
|
||||||
// jwkContextKey jwk key
|
// jwkContextKey jwk key
|
||||||
jwkContextKey = ContextKey("jwk")
|
jwkContextKey = ContextKey("jwk")
|
||||||
// payloadContextKey payload key
|
// payloadContextKey payload key
|
||||||
payloadContextKey = ContextKey("payload")
|
payloadContextKey = ContextKey("payload")
|
||||||
// provisionerContextKey provisioner key
|
|
||||||
provisionerContextKey = ContextKey("provisioner")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// accountFromContext searches the context for an ACME account. Returns the
|
// accountFromContext searches the context for an ACME account. Returns the
|
||||||
|
@ -484,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) {
|
||||||
return val, nil
|
return val, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// baseURLFromContext returns the baseURL if one is stored in the context.
|
|
||||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
|
||||||
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
|
|
||||||
if !ok || val == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
||||||
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
||||||
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
|
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
|
||||||
|
@ -514,29 +463,26 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
||||||
// provisionerFromContext searches the context for a provisioner. Returns the
|
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||||
// provisioner or an error.
|
// provisioner or an error.
|
||||||
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
||||||
val := ctx.Value(provisionerContextKey)
|
p, ok := acme.ProvisionerFromContext(ctx)
|
||||||
if val == nil {
|
if !ok || p == nil {
|
||||||
return nil, acme.NewErrorISE("provisioner expected in request context")
|
return nil, acme.NewErrorISE("provisioner expected in request context")
|
||||||
}
|
}
|
||||||
pval, ok := val.(acme.Provisioner)
|
return p, nil
|
||||||
if !ok || pval == nil {
|
|
||||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
|
||||||
}
|
|
||||||
return pval, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
|
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
|
||||||
// pointer to an ACME provisioner or an error.
|
// pointer to an ACME provisioner or an error.
|
||||||
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
|
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
|
||||||
prov, err := provisionerFromContext(ctx)
|
p, err := provisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
acmeProv, ok := prov.(*provisioner.ACME)
|
ap, ok := p.(*provisioner.ACME)
|
||||||
if !ok || acmeProv == nil {
|
if !ok {
|
||||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||||
}
|
}
|
||||||
return acmeProv, nil
|
|
||||||
|
return ap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// payloadFromContext searches the context for a payload. Returns the payload
|
// payloadFromContext searches the context for a payload. Returns the payload
|
||||||
|
|
|
@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write(testBody)
|
w.Write(testBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_baseURLFromRequest(t *testing.T) {
|
func newBaseContext(ctx context.Context, args ...interface{}) context.Context {
|
||||||
tests := []struct {
|
for _, a := range args {
|
||||||
name string
|
switch v := a.(type) {
|
||||||
targetURL string
|
case acme.DB:
|
||||||
expectedResult *url.URL
|
ctx = acme.NewDatabaseContext(ctx, v)
|
||||||
requestPreparer func(*http.Request)
|
case acme.Linker:
|
||||||
}{
|
ctx = acme.NewLinkerContext(ctx, v)
|
||||||
{
|
case acme.PrerequisitesChecker:
|
||||||
"HTTPS host pass-through failed.",
|
ctx = acme.NewPrerequisitesCheckerContext(ctx, v)
|
||||||
"https://my.dummy.host",
|
|
||||||
&url.URL{Scheme: "https", Host: "my.dummy.host"},
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"Port pass-through failed",
|
|
||||||
"https://host.with.port:8080",
|
|
||||||
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"Explicit host from Request.Host was not used.",
|
|
||||||
"https://some.target.host:8080",
|
|
||||||
&url.URL{Scheme: "https", Host: "proxied.host"},
|
|
||||||
func(r *http.Request) {
|
|
||||||
r.Host = "proxied.host"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"Missing Request.Host value did not result in empty string result.",
|
|
||||||
"https://some.host",
|
|
||||||
nil,
|
|
||||||
func(r *http.Request) {
|
|
||||||
r.Host = ""
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
request := httptest.NewRequest("GET", tc.targetURL, nil)
|
|
||||||
if tc.requestPreparer != nil {
|
|
||||||
tc.requestPreparer(request)
|
|
||||||
}
|
|
||||||
result := baseURLFromRequest(request)
|
|
||||||
if result == nil || tc.expectedResult == nil {
|
|
||||||
assert.Equals(t, result, tc.expectedResult)
|
|
||||||
} else if result.String() != tc.expectedResult.String() {
|
|
||||||
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandler_baseURLFromRequest(t *testing.T) {
|
|
||||||
h := &Handler{}
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil)
|
|
||||||
req.Host = "test.ca.smallstep.com:8080"
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
next := func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
bu := baseURLFromContext(r.Context())
|
|
||||||
if assert.NotNil(t, bu) {
|
|
||||||
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
|
|
||||||
assert.Equals(t, bu.Scheme, "https")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return ctx
|
||||||
h.baseURLFromRequest(next)(w, req)
|
|
||||||
|
|
||||||
req = httptest.NewRequest("GET", "/foo", nil)
|
|
||||||
req.Host = ""
|
|
||||||
|
|
||||||
next = func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
assert.Equals(t, baseURLFromContext(r.Context()), nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.baseURLFromRequest(next)(w, req)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_addNonce(t *testing.T) {
|
func TestHandler_addNonce(t *testing.T) {
|
||||||
|
@ -139,10 +74,10 @@ func TestHandler_addNonce(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
ctx := newBaseContext(context.Background(), tc.db)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil).WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.addNonce(testNext)(w, req)
|
addNonce(testNext)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) {
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
type test struct {
|
type test struct {
|
||||||
link string
|
link string
|
||||||
linker Linker
|
|
||||||
statusCode int
|
statusCode int
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
|
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
|
@ -195,11 +128,10 @@ func TestHandler_addDirLink(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: tc.linker}
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil)
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.addDirLink(testNext)(w, req)
|
addDirLink(testNext)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
||||||
type test struct {
|
type test struct {
|
||||||
h Handler
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
contentType string
|
contentType string
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
|
@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/provisioner-not-set": func(t *testing.T) test {
|
"fail/provisioner-not-set": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
h: Handler{
|
|
||||||
linker: NewLinker("dns", "acme"),
|
|
||||||
},
|
|
||||||
url: u,
|
url: u,
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
contentType: "foo",
|
contentType: "foo",
|
||||||
|
@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/general-bad-content-type": func(t *testing.T) test {
|
"fail/general-bad-content-type": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
h: Handler{
|
|
||||||
linker: NewLinker("dns", "acme"),
|
|
||||||
},
|
|
||||||
url: u,
|
url: u,
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
contentType: "foo",
|
contentType: "foo",
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
|
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
|
||||||
|
@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/certificate-bad-content-type": func(t *testing.T) test {
|
"fail/certificate-bad-content-type": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
h: Handler{
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
linker: NewLinker("dns", "acme"),
|
|
||||||
},
|
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
||||||
contentType: "foo",
|
contentType: "foo",
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
|
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
|
||||||
|
@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
h: Handler{
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
linker: NewLinker("dns", "acme"),
|
|
||||||
},
|
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
||||||
contentType: "application/jose+json",
|
contentType: "application/jose+json",
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/certificate/pkix-cert": func(t *testing.T) test {
|
"ok/certificate/pkix-cert": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
h: Handler{
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
linker: NewLinker("dns", "acme"),
|
|
||||||
},
|
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
||||||
contentType: "application/pkix-cert",
|
contentType: "application/pkix-cert",
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/certificate/jose+json": func(t *testing.T) test {
|
"ok/certificate/jose+json": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
h: Handler{
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
linker: NewLinker("dns", "acme"),
|
|
||||||
},
|
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
||||||
contentType: "application/jose+json",
|
contentType: "application/jose+json",
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
|
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
h: Handler{
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
linker: NewLinker("dns", "acme"),
|
|
||||||
},
|
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
||||||
contentType: "application/pkcs7-mime",
|
contentType: "application/pkcs7-mime",
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
|
@ -326,7 +236,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
req.Header.Add("Content-Type", tc.contentType)
|
req.Header.Add("Content-Type", tc.contentType)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
tc.h.verifyContentType(testNext)(w, req)
|
verifyContentType(testNext)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -390,11 +300,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
// h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.isPostAsGet(testNext)(w, req)
|
isPostAsGet(testNext)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -481,10 +391,10 @@ func TestHandler_parseJWS(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
// h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", u, tc.body)
|
req := httptest.NewRequest("GET", u, tc.body)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.parseJWS(tc.next)(w, req)
|
parseJWS(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -679,11 +589,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
// h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
parsedJWS, err := jose.ParseJWS(raw)
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
type test struct {
|
type test struct {
|
||||||
linker Linker
|
linker acme.Linker
|
||||||
db acme.DB
|
db acme.DB
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
next func(http.ResponseWriter, *http.Request)
|
next func(http.ResponseWriter, *http.Request)
|
||||||
|
@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-jws": func(t *testing.T) test {
|
"fail/no-jws": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
db: &acme.MockDB{},
|
||||||
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jws expected in request context"),
|
err: acme.NewErrorISE("jws expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-jws": func(t *testing.T) test {
|
"fail/nil-jws": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jws expected in request context"),
|
err: acme.NewErrorISE("jws expected in request context"),
|
||||||
|
@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
_jws, err := _signer.Sign([]byte("baz"))
|
_jws, err := _signer.Sign([]byte("baz"))
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
db: &acme.MockDB{},
|
||||||
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
|
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
|
||||||
|
@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
_parsed, err := jose.ParseJWS(_raw)
|
_parsed, err := jose.ParseJWS(_raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, _parsed)
|
ctx = context.WithValue(ctx, jwsContextKey, _parsed)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
db: &acme.MockDB{},
|
||||||
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
|
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/account-not-found": func(t *testing.T) test {
|
"fail/account-not-found": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||||
assert.Equals(t, accID, accID)
|
assert.Equals(t, accID, accID)
|
||||||
|
@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/GetAccount-error": func(t *testing.T) test {
|
"fail/GetAccount-error": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||||
assert.Equals(t, id, accID)
|
assert.Equals(t, id, accID)
|
||||||
|
@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/account-not-valid": func(t *testing.T) test {
|
"fail/account-not-valid": func(t *testing.T) test {
|
||||||
acc := &acme.Account{Status: "deactivated"}
|
acc := &acme.Account{Status: "deactivated"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||||
assert.Equals(t, id, accID)
|
assert.Equals(t, id, accID)
|
||||||
|
@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{Status: "valid", Key: jwk}
|
acc := &acme.Account{Status: "valid", Key: jwk}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||||
assert.Equals(t, id, accID)
|
assert.Equals(t, id, accID)
|
||||||
|
@ -881,11 +791,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: tc.linker}
|
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.lookupJWK(tc.next)(w, req)
|
lookupJWK(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-jws": func(t *testing.T) test {
|
"fail/no-jws": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
db: &acme.MockDB{},
|
||||||
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jws expected in request context"),
|
err: acme.NewErrorISE("jws expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-jws": func(t *testing.T) test {
|
"fail/nil-jws": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jws expected in request context"),
|
err: acme.NewErrorISE("jws expected in request context"),
|
||||||
|
@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
|
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
|
||||||
|
@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
|
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/GetAccountByKey-error": func(t *testing.T) test {
|
"fail/GetAccountByKey-error": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/account-not-valid": func(t *testing.T) test {
|
"fail/account-not-valid": func(t *testing.T) test {
|
||||||
acc := &acme.Account{Status: "deactivated"}
|
acc := &acme.Account{Status: "deactivated"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{Status: "valid"}
|
acc := &acme.Account{Status: "valid"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/no-account": func(t *testing.T) test {
|
"ok/no-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -1077,11 +991,11 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
ctx := newBaseContext(tc.ctx, tc.db)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.extractJWK(tc.next)(w, req)
|
extractJWK(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-jws": func(t *testing.T) test {
|
"fail/no-jws": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jws expected in request context"),
|
err: acme.NewErrorISE("jws expected in request context"),
|
||||||
|
@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-jws": func(t *testing.T) test {
|
"fail/nil-jws": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), jwsContextKey, nil),
|
ctx: context.WithValue(context.Background(), jwsContextKey, nil),
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jws expected in request context"),
|
err: acme.NewErrorISE("jws expected in request context"),
|
||||||
|
@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/no-signature": func(t *testing.T) test {
|
"fail/no-signature": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
|
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
|
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
|
||||||
|
@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
|
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
|
||||||
|
@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
|
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
|
||||||
|
@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
|
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
|
||||||
|
@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
|
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
|
||||||
|
@ -1444,11 +1365,11 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
ctx := newBaseContext(tc.ctx, tc.db)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.validateJWS(tc.next)(w, req)
|
validateJWS(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||||
u := "https://ca.smallstep.com/acme/account"
|
u := "https://ca.smallstep.com/acme/account"
|
||||||
type test struct {
|
type test struct {
|
||||||
db acme.DB
|
db acme.DB
|
||||||
linker Linker
|
linker acme.Linker
|
||||||
statusCode int
|
statusCode int
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
|
@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||||
parsedJWS, err := jose.ParseJWS(raw)
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: acme.NewLinker("dns", "acme"),
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
||||||
assert.Equals(t, kid, pub.KeyID)
|
assert.Equals(t, kid, pub.KeyID)
|
||||||
|
@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||||
parsedJWS, err := jose.ParseJWS(raw)
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
|
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("test.ca.smallstep.com", "acme"),
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||||
assert.Equals(t, accID, acc.ID)
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
@ -1628,11 +1548,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: tc.linker}
|
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.extractOrLookupJWK(tc.next)(w, req)
|
extractOrLookupJWK(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
||||||
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||||
baseURL, provName)
|
baseURL, provName)
|
||||||
type test struct {
|
type test struct {
|
||||||
linker Linker
|
linker acme.Linker
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
prerequisitesChecker func(context.Context) (bool, error)
|
prerequisitesChecker func(context.Context) (bool, error)
|
||||||
next func(http.ResponseWriter, *http.Request)
|
next func(http.ResponseWriter, *http.Request)
|
||||||
|
@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/error": func(t *testing.T) test {
|
"fail/error": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: acme.NewLinker("dns", "acme"),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
|
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
|
||||||
next: func(w http.ResponseWriter, r *http.Request) {
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/prerequisites-nok": func(t *testing.T) test {
|
"fail/prerequisites-nok": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: acme.NewLinker("dns", "acme"),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
|
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
|
||||||
next: func(w http.ResponseWriter, r *http.Request) {
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: acme.NewLinker("dns", "acme"),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
|
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
|
||||||
next: func(w http.ResponseWriter, r *http.Request) {
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -1717,11 +1634,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker)
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.checkPrerequisites(tc.next)(w, req)
|
checkPrerequisites(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
|
@ -72,8 +72,12 @@ var defaultOrderExpiry = time.Hour * 24
|
||||||
var defaultOrderBackdate = time.Minute
|
var defaultOrderBackdate = time.Minute
|
||||||
|
|
||||||
// NewOrder ACME api for creating a new order.
|
// NewOrder ACME api for creating a new order.
|
||||||
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
ca := mustAuthority(ctx)
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -113,7 +117,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
var eak *acme.ExternalAccountKey
|
var eak *acme.ExternalAccountKey
|
||||||
if acmeProv.RequireEAB {
|
if acmeProv.RequireEAB {
|
||||||
if eak, err = h.db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
|
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -138,7 +142,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// evaluate the authority level policy
|
// evaluate the authority level policy
|
||||||
if err = h.ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
|
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
|
||||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -164,7 +168,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ExpiresAt: o.ExpiresAt,
|
ExpiresAt: o.ExpiresAt,
|
||||||
Status: acme.StatusPending,
|
Status: acme.StatusPending,
|
||||||
}
|
}
|
||||||
if err := h.newAuthorization(ctx, az); err != nil {
|
if err := newAuthorization(ctx, az); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -183,14 +187,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
|
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.CreateOrder(ctx, o); err != nil {
|
if err := db.CreateOrder(ctx, o); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkOrder(ctx, o)
|
linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||||
render.JSONStatus(w, o, http.StatusCreated)
|
render.JSONStatus(w, o, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,7 +212,7 @@ func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error
|
||||||
return policy.NewX509PolicyEngine(eak.Policy)
|
return policy.NewX509PolicyEngine(eak.Policy)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||||
if strings.HasPrefix(az.Identifier.Value, "*.") {
|
if strings.HasPrefix(az.Identifier.Value, "*.") {
|
||||||
az.Wildcard = true
|
az.Wildcard = true
|
||||||
az.Identifier = acme.Identifier{
|
az.Identifier = acme.Identifier{
|
||||||
|
@ -224,6 +228,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
az.Challenges = make([]*acme.Challenge, len(chTypes))
|
az.Challenges = make([]*acme.Challenge, len(chTypes))
|
||||||
for i, typ := range chTypes {
|
for i, typ := range chTypes {
|
||||||
ch := &acme.Challenge{
|
ch := &acme.Challenge{
|
||||||
|
@ -233,20 +239,23 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
|
||||||
Token: az.Token,
|
Token: az.Token,
|
||||||
Status: acme.StatusPending,
|
Status: acme.StatusPending,
|
||||||
}
|
}
|
||||||
if err := h.db.CreateChallenge(ctx, ch); err != nil {
|
if err := db.CreateChallenge(ctx, ch); err != nil {
|
||||||
return acme.WrapErrorISE(err, "error creating challenge")
|
return acme.WrapErrorISE(err, "error creating challenge")
|
||||||
}
|
}
|
||||||
az.Challenges[i] = ch
|
az.Challenges[i] = ch
|
||||||
}
|
}
|
||||||
if err = h.db.CreateAuthorization(ctx, az); err != nil {
|
if err = db.CreateAuthorization(ctx, az); err != nil {
|
||||||
return acme.WrapErrorISE(err, "error creating authorization")
|
return acme.WrapErrorISE(err, "error creating authorization")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrder ACME api for retrieving an order.
|
// GetOrder ACME api for retrieving an order.
|
||||||
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
func GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -257,7 +266,8 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
|
||||||
|
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||||
return
|
return
|
||||||
|
@ -272,20 +282,23 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = o.UpdateStatus(ctx, h.db); err != nil {
|
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
|
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkOrder(ctx, o)
|
linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||||
render.JSON(w, o)
|
render.JSON(w, o)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
// FinalizeOrder attempts to finalize an order and create a certificate.
|
||||||
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -312,7 +325,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||||
return
|
return
|
||||||
|
@ -327,14 +340,16 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
|
|
||||||
|
ca := mustAuthority(ctx)
|
||||||
|
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
|
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkOrder(ctx, o)
|
linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||||
render.JSON(w, o)
|
render.JSON(w, o)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -280,15 +280,17 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
db: &acme.MockDB{},
|
||||||
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
|
@ -298,6 +300,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner does not exist"),
|
err: acme.NewErrorISE("provisioner does not exist"),
|
||||||
|
@ -305,9 +308,10 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-provisioner": func(t *testing.T) test {
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
ctx := acme.NewProvisionerContext(context.Background(), nil)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner does not exist"),
|
err: acme.NewErrorISE("provisioner does not exist"),
|
||||||
|
@ -315,7 +319,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/db.GetOrder-error": func(t *testing.T) test {
|
"fail/db.GetOrder-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
|
@ -329,7 +333,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
|
@ -345,7 +349,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/provisioner-id-mismatch": func(t *testing.T) test {
|
"fail/provisioner-id-mismatch": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
|
@ -361,7 +365,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/order-update-error": func(t *testing.T) test {
|
"fail/order-update-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
|
@ -385,10 +389,9 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
|
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
|
||||||
|
@ -425,11 +428,11 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetOrder(w, req)
|
GetOrder(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -640,8 +643,8 @@ func TestHandler_newAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
h := &Handler{db: tc.db}
|
ctx := newBaseContext(context.Background(), tc.db)
|
||||||
if err := h.newAuthorization(context.Background(), tc.az); err != nil {
|
if err := newAuthorization(ctx, tc.az); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
|
@ -682,15 +685,17 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
db: &acme.MockDB{},
|
||||||
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
|
@ -700,6 +705,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner does not exist"),
|
err: acme.NewErrorISE("provisioner does not exist"),
|
||||||
|
@ -707,9 +713,10 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-provisioner": func(t *testing.T) test {
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner does not exist"),
|
err: acme.NewErrorISE("provisioner does not exist"),
|
||||||
|
@ -718,8 +725,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
"fail/no-payload": func(t *testing.T) test {
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload does not exist"),
|
err: acme.NewErrorISE("payload does not exist"),
|
||||||
|
@ -727,21 +735,23 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-payload": func(t *testing.T) test {
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("paylod does not exist"),
|
err: acme.NewErrorISE("payload does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
|
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
|
||||||
|
@ -752,10 +762,11 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
fr := &NewOrderRequest{}
|
fr := &NewOrderRequest{}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
|
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
|
||||||
|
@ -770,7 +781,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, &acme.MockProvisioner{})
|
ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{})
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
@ -798,7 +809,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv)
|
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
@ -826,7 +837,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv)
|
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
@ -862,7 +873,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv)
|
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
@ -905,7 +916,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy)
|
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
@ -948,7 +959,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy)
|
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
@ -986,7 +997,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
@ -1020,7 +1031,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
var (
|
var (
|
||||||
|
@ -1096,10 +1107,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nor)
|
b, err := json.Marshal(nor)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
var (
|
var (
|
||||||
ch1, ch2, ch3, ch4 **acme.Challenge
|
ch1, ch2, ch3, ch4 **acme.Challenge
|
||||||
az1ID, az2ID *string
|
az1ID, az2ID *string
|
||||||
|
@ -1217,10 +1227,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nor)
|
b, err := json.Marshal(nor)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
var (
|
var (
|
||||||
ch1, ch2, ch3 **acme.Challenge
|
ch1, ch2, ch3 **acme.Challenge
|
||||||
az1ID *string
|
az1ID *string
|
||||||
|
@ -1315,10 +1324,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nor)
|
b, err := json.Marshal(nor)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
var (
|
var (
|
||||||
ch1, ch2, ch3 **acme.Challenge
|
ch1, ch2, ch3 **acme.Challenge
|
||||||
az1ID *string
|
az1ID *string
|
||||||
|
@ -1412,10 +1420,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nor)
|
b, err := json.Marshal(nor)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
var (
|
var (
|
||||||
ch1, ch2, ch3 **acme.Challenge
|
ch1, ch2, ch3 **acme.Challenge
|
||||||
az1ID *string
|
az1ID *string
|
||||||
|
@ -1510,10 +1517,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nor)
|
b, err := json.Marshal(nor)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
var (
|
var (
|
||||||
ch1, ch2, ch3 **acme.Challenge
|
ch1, ch2, ch3 **acme.Challenge
|
||||||
az1ID *string
|
az1ID *string
|
||||||
|
@ -1611,10 +1617,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nor)
|
b, err := json.Marshal(nor)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy)
|
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
var (
|
var (
|
||||||
ch1, ch2, ch3 **acme.Challenge
|
ch1, ch2, ch3 **acme.Challenge
|
||||||
az1ID *string
|
az1ID *string
|
||||||
|
@ -1701,11 +1706,12 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
|
mockMustAuthority(t, tc.ca)
|
||||||
|
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.NewOrder(w, req)
|
NewOrder(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -1738,6 +1744,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_FinalizeOrder(t *testing.T) {
|
func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
|
mockMustAuthority(t, &mockCA{})
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
escProvName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
@ -1796,15 +1803,17 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
db: &acme.MockDB{},
|
||||||
|
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
|
@ -1814,6 +1823,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner does not exist"),
|
err: acme.NewErrorISE("provisioner does not exist"),
|
||||||
|
@ -1821,9 +1831,10 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-provisioner": func(t *testing.T) test {
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner does not exist"),
|
err: acme.NewErrorISE("provisioner does not exist"),
|
||||||
|
@ -1832,8 +1843,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
"fail/no-payload": func(t *testing.T) test {
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload does not exist"),
|
err: acme.NewErrorISE("payload does not exist"),
|
||||||
|
@ -1841,21 +1853,23 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-payload": func(t *testing.T) test {
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("paylod does not exist"),
|
err: acme.NewErrorISE("payload does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
|
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
|
||||||
|
@ -1866,10 +1880,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
fr := &FinalizeRequest{}
|
fr := &FinalizeRequest{}
|
||||||
b, err := json.Marshal(fr)
|
b, err := json.Marshal(fr)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
|
err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
|
||||||
|
@ -1878,7 +1893,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
"fail/db.GetOrder-error": func(t *testing.T) test {
|
"fail/db.GetOrder-error": func(t *testing.T) test {
|
||||||
|
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
@ -1893,7 +1908,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
@ -1910,7 +1925,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/provisioner-id-mismatch": func(t *testing.T) test {
|
"fail/provisioner-id-mismatch": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
@ -1927,7 +1942,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/order-finalize-error": func(t *testing.T) test {
|
"fail/order-finalize-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
@ -1952,10 +1967,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID"}
|
acc := &acme.Account{ID: "accountID"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
|
@ -1991,11 +2005,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||||
req := httptest.NewRequest("GET", u, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.FinalizeOrder(w, req)
|
FinalizeOrder(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
|
@ -26,9 +26,11 @@ type revokePayload struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeCert attempts to revoke a certificate.
|
// RevokeCert attempts to revoke a certificate.
|
||||||
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -69,7 +71,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
serial := certToBeRevoked.SerialNumber.String()
|
serial := certToBeRevoked.SerialNumber.String()
|
||||||
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
dbCert, err := db.GetCertificateBySerial(ctx, serial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||||
return
|
return
|
||||||
|
@ -87,7 +89,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||||
if acmeErr != nil {
|
if acmeErr != nil {
|
||||||
render.Error(w, acmeErr)
|
render.Error(w, acmeErr)
|
||||||
return
|
return
|
||||||
|
@ -103,7 +105,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
|
ca := mustAuthority(ctx)
|
||||||
|
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||||
return
|
return
|
||||||
|
@ -130,14 +133,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
||||||
err = h.ca.Revoke(ctx, options)
|
err = ca.Revoke(ctx, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, wrapRevokeErr(err))
|
render.Error(w, wrapRevokeErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logRevoke(w, options)
|
logRevoke(w, options)
|
||||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index"))
|
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||||
w.Write(nil)
|
w.Write(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,7 +151,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
|
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
|
||||||
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
|
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
|
||||||
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
|
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
|
||||||
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||||
if !account.IsValid() {
|
if !account.IsValid() {
|
||||||
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
|
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
|
||||||
}
|
}
|
||||||
|
|
|
@ -521,6 +521,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
"fail/no-jws": func(t *testing.T) test {
|
"fail/no-jws": func(t *testing.T) test {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jws expected in request context"),
|
err: acme.NewErrorISE("jws expected in request context"),
|
||||||
|
@ -529,6 +530,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
"fail/nil-jws": func(t *testing.T) test {
|
"fail/nil-jws": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), jwsContextKey, nil)
|
ctx := context.WithValue(context.Background(), jwsContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("jws expected in request context"),
|
err: acme.NewErrorISE("jws expected in request context"),
|
||||||
|
@ -537,6 +539,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
"fail/no-provisioner": func(t *testing.T) test {
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner does not exist"),
|
err: acme.NewErrorISE("provisioner does not exist"),
|
||||||
|
@ -544,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-provisioner": func(t *testing.T) test {
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, nil)
|
ctx = acme.NewProvisionerContext(ctx, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("provisioner does not exist"),
|
err: acme.NewErrorISE("provisioner does not exist"),
|
||||||
|
@ -553,8 +557,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/no-payload": func(t *testing.T) test {
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload does not exist"),
|
err: acme.NewErrorISE("payload does not exist"),
|
||||||
|
@ -562,9 +567,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/nil-payload": func(t *testing.T) test {
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("payload does not exist"),
|
err: acme.NewErrorISE("payload does not exist"),
|
||||||
|
@ -573,9 +579,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
"fail/unmarshal-payload": func(t *testing.T) test {
|
"fail/unmarshal-payload": func(t *testing.T) test {
|
||||||
malformedPayload := []byte(`{"payload":malformed?}`)
|
malformedPayload := []byte(`{"payload":malformed?}`)
|
||||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: acme.NewErrorISE("error unmarshaling payload"),
|
err: acme.NewErrorISE("error unmarshaling payload"),
|
||||||
|
@ -587,10 +594,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
}
|
}
|
||||||
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
|
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: &acme.Error{
|
err: &acme.Error{
|
||||||
|
@ -606,10 +614,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
}
|
}
|
||||||
emptyPayloadBytes, err := json.Marshal(emptyPayload)
|
emptyPayloadBytes, err := json.Marshal(emptyPayload)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
return test{
|
return test{
|
||||||
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: &acme.Error{
|
err: &acme.Error{
|
||||||
|
@ -620,7 +629,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/db.GetCertificateBySerial": func(t *testing.T) test {
|
"fail/db.GetCertificateBySerial": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
|
@ -638,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
"fail/different-certificate-contents": func(t *testing.T) test {
|
"fail/different-certificate-contents": func(t *testing.T) test {
|
||||||
aDifferentCert, _, err := generateCertKeyPair()
|
aDifferentCert, _, err := generateCertKeyPair()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
|
@ -657,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
|
@ -676,7 +685,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
|
@ -697,11 +706,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/account-not-valid": func(t *testing.T) test {
|
"fail/account-not-valid": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||||
|
@ -727,11 +735,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/account-not-authorized": func(t *testing.T) test {
|
"fail/account-not-authorized": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||||
|
@ -781,10 +788,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
|
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||||
|
@ -808,11 +814,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/certificate-revoked-check-fails": func(t *testing.T) test {
|
"fail/certificate-revoked-check-fails": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||||
|
@ -842,7 +847,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/certificate-already-revoked": func(t *testing.T) test {
|
"fail/certificate-already-revoked": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
|
@ -880,7 +885,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
|
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
|
@ -918,7 +923,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv)
|
ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
|
@ -950,7 +955,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/ca.Revoke": func(t *testing.T) test {
|
"fail/ca.Revoke": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
|
@ -982,7 +987,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/ca.Revoke-already-revoked": func(t *testing.T) test {
|
"fail/ca.Revoke-already-revoked": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
|
@ -1013,11 +1018,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok/using-account-key": func(t *testing.T) test {
|
"ok/using-account-key": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||||
|
@ -1041,10 +1045,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
jws, err := jose.ParseJWS(string(jwsBytes))
|
jws, err := jose.ParseJWS(string(jwsBytes))
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
db := &acme.MockDB{
|
db := &acme.MockDB{
|
||||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||||
|
@ -1067,11 +1070,12 @@ func TestHandler_RevokeCert(t *testing.T) {
|
||||||
for name, setup := range tests {
|
for name, setup := range tests {
|
||||||
tc := setup(t)
|
tc := setup(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
|
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||||
|
mockMustAuthority(t, tc.ca)
|
||||||
req := httptest.NewRequest("POST", revokeURL, nil)
|
req := httptest.NewRequest("POST", revokeURL, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.RevokeCert(w, req)
|
RevokeCert(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -1208,8 +1212,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
|
||||||
for name, setup := range tests {
|
for name, setup := range tests {
|
||||||
tc := setup(t)
|
tc := setup(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
// h := &Handler{db: tc.db}
|
||||||
acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
|
acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
|
||||||
|
|
||||||
expectError := tc.err != nil
|
expectError := tc.err != nil
|
||||||
gotError := acmeErr != nil
|
gotError := acmeErr != nil
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
|
||||||
// type using the DB interface.
|
// type using the DB interface.
|
||||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||||
// updated.
|
// updated.
|
||||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
|
||||||
// If already valid or invalid then return without performing validation.
|
// If already valid or invalid then return without performing validation.
|
||||||
if ch.Status != StatusPending {
|
if ch.Status != StatusPending {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
switch ch.Type {
|
switch ch.Type {
|
||||||
case HTTP01:
|
case HTTP01:
|
||||||
return http01Validate(ctx, ch, db, jwk, vo)
|
return http01Validate(ctx, ch, db, jwk)
|
||||||
case DNS01:
|
case DNS01:
|
||||||
return dns01Validate(ctx, ch, db, jwk, vo)
|
return dns01Validate(ctx, ch, db, jwk)
|
||||||
case TLSALPN01:
|
case TLSALPN01:
|
||||||
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
|
return tlsalpn01Validate(ctx, ch, db, jwk)
|
||||||
default:
|
default:
|
||||||
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
|
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||||
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||||
|
|
||||||
resp, err := vo.HTTPGet(u.String())
|
vc := MustClientFromContext(ctx)
|
||||||
|
resp, err := vc.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||||
"error doing http GET for url %s", u))
|
"error doing http GET for url %s", u))
|
||||||
|
@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||||
config := &tls.Config{
|
config := &tls.Config{
|
||||||
NextProtos: []string{"acme-tls/1"},
|
NextProtos: []string{"acme-tls/1"},
|
||||||
// https://tools.ietf.org/html/rfc8737#section-4
|
// https://tools.ietf.org/html/rfc8737#section-4
|
||||||
|
@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
||||||
|
|
||||||
hostPort := net.JoinHostPort(ch.Value, "443")
|
hostPort := net.JoinHostPort(ch.Value, "443")
|
||||||
|
|
||||||
conn, err := vo.TLSDial("tcp", hostPort, config)
|
vc := MustClientFromContext(ctx)
|
||||||
|
conn, err := vc.TLSDial("tcp", hostPort, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
|
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
|
||||||
// client and server protocols. When this happens the connection is
|
// client and server protocols. When this happens the connection is
|
||||||
|
@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
||||||
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
|
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||||
// Normalize domain for wildcard DNS names
|
// Normalize domain for wildcard DNS names
|
||||||
// This is done to avoid making TXT lookups for domains like
|
// This is done to avoid making TXT lookups for domains like
|
||||||
// _acme-challenge.*.example.com
|
// _acme-challenge.*.example.com
|
||||||
// Instead perform txt lookup for _acme-challenge.example.com
|
// Instead perform txt lookup for _acme-challenge.example.com
|
||||||
domain := strings.TrimPrefix(ch.Value, "*.")
|
domain := strings.TrimPrefix(ch.Value, "*.")
|
||||||
|
|
||||||
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
|
vc := MustClientFromContext(ctx)
|
||||||
|
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
|
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
|
||||||
"error looking up TXT records for domain %s", domain))
|
"error looking up TXT records for domain %s", domain))
|
||||||
|
@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpGetter func(string) (*http.Response, error)
|
|
||||||
type lookupTxt func(string) ([]string, error)
|
|
||||||
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
|
||||||
|
|
||||||
// ValidateChallengeOptions are ACME challenge validator functions.
|
|
||||||
type ValidateChallengeOptions struct {
|
|
||||||
HTTPGet httpGetter
|
|
||||||
LookupTxt lookupTxt
|
|
||||||
TLSDial tlsDialer
|
|
||||||
}
|
|
||||||
|
|
|
@ -29,6 +29,18 @@ import (
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockClient struct {
|
||||||
|
get func(url string) (*http.Response, error)
|
||||||
|
lookupTxt func(name string) ([]string, error)
|
||||||
|
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
|
||||||
|
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
|
||||||
|
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
|
return m.tlsDial(network, addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_storeError(t *testing.T) {
|
func Test_storeError(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
ch *Challenge
|
ch *Challenge
|
||||||
|
@ -229,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) {
|
||||||
func TestChallenge_Validate(t *testing.T) {
|
func TestChallenge_Validate(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
ch *Challenge
|
ch *Challenge
|
||||||
vo *ValidateChallengeOptions
|
vc Client
|
||||||
jwk *jose.JSONWebKey
|
jwk *jose.JSONWebKey
|
||||||
db DB
|
db DB
|
||||||
srv *httptest.Server
|
srv *httptest.Server
|
||||||
|
@ -273,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -309,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -344,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -381,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -416,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) {
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -466,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -493,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) {
|
||||||
defer tc.srv.Close()
|
defer tc.srv.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil {
|
ctx := NewClientContext(context.Background(), tc.vc)
|
||||||
|
if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *Error:
|
case *Error:
|
||||||
|
@ -524,7 +537,7 @@ func (errReader) Close() error {
|
||||||
|
|
||||||
func TestHTTP01Validate(t *testing.T) {
|
func TestHTTP01Validate(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
vo *ValidateChallengeOptions
|
vc Client
|
||||||
ch *Challenge
|
ch *Challenge
|
||||||
jwk *jose.JSONWebKey
|
jwk *jose.JSONWebKey
|
||||||
db DB
|
db DB
|
||||||
|
@ -541,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -575,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -608,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
StatusCode: http.StatusBadRequest,
|
StatusCode: http.StatusBadRequest,
|
||||||
Body: errReader(0),
|
Body: errReader(0),
|
||||||
|
@ -645,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
StatusCode: http.StatusBadRequest,
|
StatusCode: http.StatusBadRequest,
|
||||||
Body: errReader(0),
|
Body: errReader(0),
|
||||||
|
@ -681,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
Body: errReader(0),
|
Body: errReader(0),
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -704,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
jwk.Key = "foo"
|
jwk.Key = "foo"
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -730,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -772,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -815,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -857,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
HTTPGet: func(url string) (*http.Response, error) {
|
get: func(url string) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -887,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
ctx := NewClientContext(context.Background(), tc.vc)
|
||||||
|
if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *Error:
|
case *Error:
|
||||||
|
@ -911,7 +925,7 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
fulldomain := "*.zap.internal"
|
fulldomain := "*.zap.internal"
|
||||||
domain := strings.TrimPrefix(fulldomain, "*.")
|
domain := strings.TrimPrefix(fulldomain, "*.")
|
||||||
type test struct {
|
type test struct {
|
||||||
vo *ValidateChallengeOptions
|
vc Client
|
||||||
ch *Challenge
|
ch *Challenge
|
||||||
jwk *jose.JSONWebKey
|
jwk *jose.JSONWebKey
|
||||||
db DB
|
db DB
|
||||||
|
@ -928,8 +942,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -963,8 +977,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1001,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return []string{"foo"}, nil
|
return []string{"foo"}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1026,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return []string{"foo", "bar"}, nil
|
return []string{"foo", "bar"}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1068,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return []string{"foo", "bar"}, nil
|
return []string{"foo", "bar"}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1111,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return []string{"foo", expected}, nil
|
return []string{"foo", expected}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1156,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
LookupTxt: func(url string) ([]string, error) {
|
lookupTxt: func(url string) ([]string, error) {
|
||||||
return []string{"foo", expected}, nil
|
return []string{"foo", expected}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1186,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
ctx := NewClientContext(context.Background(), tc.vc)
|
||||||
|
if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *Error:
|
case *Error:
|
||||||
|
@ -1206,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error)
|
||||||
|
|
||||||
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
|
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
|
||||||
srv := httptest.NewUnstartedServer(http.NewServeMux())
|
srv := httptest.NewUnstartedServer(http.NewServeMux())
|
||||||
|
|
||||||
|
@ -1309,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
type test struct {
|
type test struct {
|
||||||
vo *ValidateChallengeOptions
|
vc Client
|
||||||
ch *Challenge
|
ch *Challenge
|
||||||
jwk *jose.JSONWebKey
|
jwk *jose.JSONWebKey
|
||||||
db DB
|
db DB
|
||||||
|
@ -1321,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
ch := makeTLSCh()
|
ch := makeTLSCh()
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1351,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
ch := makeTLSCh()
|
ch := makeTLSCh()
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1384,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1413,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
return tls.Client(&noopConn{}, config), nil
|
return tls.Client(&noopConn{}, config), nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1443,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
return tls.Client(&noopConn{}, config), nil
|
return tls.Client(&noopConn{}, config), nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1479,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1516,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1562,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1605,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1649,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1692,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1736,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
srv: srv,
|
srv: srv,
|
||||||
jwk: jwk,
|
jwk: jwk,
|
||||||
|
@ -1758,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1797,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1841,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1884,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1924,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -1963,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -2008,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -2054,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -2100,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -2144,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -2189,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -2226,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ch: ch,
|
ch: ch,
|
||||||
vo: &ValidateChallengeOptions{
|
vc: &mockClient{
|
||||||
TLSDial: tlsDial,
|
tlsDial: tlsDial,
|
||||||
},
|
},
|
||||||
db: &MockDB{
|
db: &MockDB{
|
||||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||||
|
@ -2253,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
defer tc.srv.Close()
|
defer tc.srv.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
ctx := NewClientContext(context.Background(), tc.vc)
|
||||||
|
if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *Error:
|
case *Error:
|
||||||
|
|
79
acme/client.go
Normal file
79
acme/client.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client is the interface used to verify ACME challenges.
|
||||||
|
type Client interface {
|
||||||
|
// Get issues an HTTP GET to the specified URL.
|
||||||
|
Get(url string) (*http.Response, error)
|
||||||
|
|
||||||
|
// LookupTXT returns the DNS TXT records for the given domain name.
|
||||||
|
LookupTxt(name string) ([]string, error)
|
||||||
|
|
||||||
|
// TLSDial connects to the given network address using net.Dialer and then
|
||||||
|
// initiates a TLS handshake, returning the resulting TLS connection.
|
||||||
|
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientKey struct{}
|
||||||
|
|
||||||
|
// NewClientContext adds the given client to the context.
|
||||||
|
func NewClientContext(ctx context.Context, c Client) context.Context {
|
||||||
|
return context.WithValue(ctx, clientKey{}, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientFromContext returns the current client from the given context.
|
||||||
|
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
|
||||||
|
c, ok = ctx.Value(clientKey{}).(Client)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustClientFromContext returns the current client from the given context. It will
|
||||||
|
// return a new instance of the client if it does not exist.
|
||||||
|
func MustClientFromContext(ctx context.Context) Client {
|
||||||
|
c, ok := ClientFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return NewClient()
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
type client struct {
|
||||||
|
http *http.Client
|
||||||
|
dialer *net.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient returns an implementation of Client for verifying ACME challenges.
|
||||||
|
func NewClient() Client {
|
||||||
|
return &client{
|
||||||
|
http: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dialer: &net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) Get(url string) (*http.Response, error) {
|
||||||
|
return c.http.Get(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) LookupTxt(name string) ([]string, error) {
|
||||||
|
return net.LookupTXT(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
|
return tls.DialWithDialer(c.dialer, network, addr, config)
|
||||||
|
}
|
|
@ -9,15 +9,6 @@ import (
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CertificateAuthority is the interface implemented by a CA authority.
|
|
||||||
type CertificateAuthority interface {
|
|
||||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
|
||||||
AreSANsAllowed(ctx context.Context, sans []string) error
|
|
||||||
IsRevoked(sn string) (bool, error)
|
|
||||||
Revoke(context.Context, *authority.RevokeOptions) error
|
|
||||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clock that returns time in UTC rounded to seconds.
|
// Clock that returns time in UTC rounded to seconds.
|
||||||
type Clock struct{}
|
type Clock struct{}
|
||||||
|
|
||||||
|
@ -28,6 +19,52 @@ func (c *Clock) Now() time.Time {
|
||||||
|
|
||||||
var clock Clock
|
var clock Clock
|
||||||
|
|
||||||
|
// CertificateAuthority is the interface implemented by a CA authority.
|
||||||
|
type CertificateAuthority interface {
|
||||||
|
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||||
|
AreSANsAllowed(ctx context.Context, sans []string) error
|
||||||
|
IsRevoked(sn string) (bool, error)
|
||||||
|
Revoke(context.Context, *authority.RevokeOptions) error
|
||||||
|
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContext adds the given acme components to the context.
|
||||||
|
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
|
||||||
|
ctx = NewDatabaseContext(ctx, db)
|
||||||
|
ctx = NewClientContext(ctx, client)
|
||||||
|
ctx = NewLinkerContext(ctx, linker)
|
||||||
|
// Prerequisite checker is optional.
|
||||||
|
if fn != nil {
|
||||||
|
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrerequisitesChecker is a function that checks if all prerequisites for
|
||||||
|
// serving ACME are met by the CA configuration.
|
||||||
|
type PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||||
|
|
||||||
|
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
|
||||||
|
// always true.
|
||||||
|
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type prerequisitesKey struct{}
|
||||||
|
|
||||||
|
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
|
||||||
|
// context.
|
||||||
|
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
|
||||||
|
return context.WithValue(ctx, prerequisitesKey{}, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
|
||||||
|
// context.
|
||||||
|
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
|
||||||
|
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
|
||||||
|
return fn, ok && fn != nil
|
||||||
|
}
|
||||||
|
|
||||||
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
||||||
// only those methods required by the ACME api/authority.
|
// only those methods required by the ACME api/authority.
|
||||||
type Provisioner interface {
|
type Provisioner interface {
|
||||||
|
@ -40,6 +77,29 @@ type Provisioner interface {
|
||||||
GetOptions() *provisioner.Options
|
GetOptions() *provisioner.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type provisionerKey struct{}
|
||||||
|
|
||||||
|
// NewProvisionerContext adds the given provisioner to the context.
|
||||||
|
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
|
||||||
|
return context.WithValue(ctx, provisionerKey{}, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvisionerFromContext returns the current provisioner from the given context.
|
||||||
|
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
|
||||||
|
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustLinkerFromContext returns the current provisioner from the given context.
|
||||||
|
// It will panic if it's not in the context.
|
||||||
|
func MustProvisionerFromContext(ctx context.Context) Provisioner {
|
||||||
|
if v, ok := ProvisionerFromContext(ctx); !ok {
|
||||||
|
panic("acme provisioner is not the context")
|
||||||
|
} else {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MockProvisioner for testing
|
// MockProvisioner for testing
|
||||||
type MockProvisioner struct {
|
type MockProvisioner struct {
|
||||||
Mret1 interface{}
|
Mret1 interface{}
|
||||||
|
|
23
acme/db.go
23
acme/db.go
|
@ -49,6 +49,29 @@ type DB interface {
|
||||||
UpdateOrder(ctx context.Context, o *Order) error
|
UpdateOrder(ctx context.Context, o *Order) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dbKey struct{}
|
||||||
|
|
||||||
|
// NewDatabaseContext adds the given acme database to the context.
|
||||||
|
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
|
||||||
|
return context.WithValue(ctx, dbKey{}, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabaseFromContext returns the current acme database from the given context.
|
||||||
|
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
|
||||||
|
db, ok = ctx.Value(dbKey{}).(DB)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustDatabaseFromContext returns the current database from the given context.
|
||||||
|
// It will panic if it's not in the context.
|
||||||
|
func MustDatabaseFromContext(ctx context.Context) DB {
|
||||||
|
if db, ok := DatabaseFromContext(ctx); !ok {
|
||||||
|
panic("acme database is not in the context")
|
||||||
|
} else {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MockDB is an implementation of the DB interface that should only be used as
|
// MockDB is an implementation of the DB interface that should only be used as
|
||||||
// a mock in tests.
|
// a mock in tests.
|
||||||
type MockDB struct {
|
type MockDB struct {
|
||||||
|
|
|
@ -1,100 +1,19 @@
|
||||||
package api
|
package acme
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewLinker returns a new Directory type.
|
|
||||||
func NewLinker(dns, prefix string) Linker {
|
|
||||||
_, _, err := net.SplitHostPort(dns)
|
|
||||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
|
||||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
|
||||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
|
||||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
|
||||||
// these cases, then the input dns is not changed.
|
|
||||||
lastIndex := strings.LastIndex(dns, ":")
|
|
||||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
|
||||||
if ip := net.ParseIP(hostPart); ip != nil {
|
|
||||||
dns = "[" + hostPart + "]:" + portPart
|
|
||||||
} else if ip := net.ParseIP(dns); ip != nil {
|
|
||||||
dns = "[" + dns + "]"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &linker{prefix: prefix, dns: dns}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Linker interface for generating links for ACME resources.
|
|
||||||
type Linker interface {
|
|
||||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
|
||||||
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
|
|
||||||
|
|
||||||
LinkOrder(ctx context.Context, o *acme.Order)
|
|
||||||
LinkAccount(ctx context.Context, o *acme.Account)
|
|
||||||
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
|
|
||||||
LinkAuthorization(ctx context.Context, o *acme.Authorization)
|
|
||||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// linker generates ACME links.
|
|
||||||
type linker struct {
|
|
||||||
prefix string
|
|
||||||
dns string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
|
||||||
switch typ {
|
|
||||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
|
||||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
|
||||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
|
||||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
|
||||||
case ChallengeLinkType:
|
|
||||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
|
||||||
case OrdersByAccountLinkType:
|
|
||||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
|
||||||
case FinalizeLinkType:
|
|
||||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLink is a helper for GetLinkExplicit
|
|
||||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
|
||||||
var (
|
|
||||||
provName string
|
|
||||||
baseURL = baseURLFromContext(ctx)
|
|
||||||
u = url.URL{}
|
|
||||||
)
|
|
||||||
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
|
|
||||||
provName = p.GetName()
|
|
||||||
}
|
|
||||||
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
|
||||||
if baseURL != nil {
|
|
||||||
u = *baseURL
|
|
||||||
}
|
|
||||||
|
|
||||||
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
|
|
||||||
|
|
||||||
// If no Scheme is set, then default to https.
|
|
||||||
if u.Scheme == "" {
|
|
||||||
u.Scheme = "https"
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
|
||||||
if u.Host == "" {
|
|
||||||
u.Host = l.dns
|
|
||||||
}
|
|
||||||
|
|
||||||
u.Path = l.prefix + u.Path
|
|
||||||
return u.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// LinkType captures the link type.
|
// LinkType captures the link type.
|
||||||
type LinkType int
|
type LinkType int
|
||||||
|
|
||||||
|
@ -160,8 +79,155 @@ func (l LinkType) String() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||||
|
switch typ {
|
||||||
|
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||||
|
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||||
|
case ChallengeLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||||
|
case OrdersByAccountLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||||
|
case FinalizeLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLinker returns a new Directory type.
|
||||||
|
func NewLinker(dns, prefix string) Linker {
|
||||||
|
_, _, err := net.SplitHostPort(dns)
|
||||||
|
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||||
|
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||||
|
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||||
|
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||||
|
// these cases, then the input dns is not changed.
|
||||||
|
lastIndex := strings.LastIndex(dns, ":")
|
||||||
|
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||||
|
if ip := net.ParseIP(hostPart); ip != nil {
|
||||||
|
dns = "[" + hostPart + "]:" + portPart
|
||||||
|
} else if ip := net.ParseIP(dns); ip != nil {
|
||||||
|
dns = "[" + dns + "]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &linker{prefix: prefix, dns: dns}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Linker interface for generating links for ACME resources.
|
||||||
|
type Linker interface {
|
||||||
|
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||||
|
Middleware(http.Handler) http.Handler
|
||||||
|
LinkOrder(ctx context.Context, o *Order)
|
||||||
|
LinkAccount(ctx context.Context, o *Account)
|
||||||
|
LinkChallenge(ctx context.Context, o *Challenge, azID string)
|
||||||
|
LinkAuthorization(ctx context.Context, o *Authorization)
|
||||||
|
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type linkerKey struct{}
|
||||||
|
|
||||||
|
// NewLinkerContext adds the given linker to the context.
|
||||||
|
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
|
||||||
|
return context.WithValue(ctx, linkerKey{}, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkerFromContext returns the current linker from the given context.
|
||||||
|
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
|
||||||
|
v, ok = ctx.Value(linkerKey{}).(Linker)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustLinkerFromContext returns the current linker from the given context. It
|
||||||
|
// will panic if it's not in the context.
|
||||||
|
func MustLinkerFromContext(ctx context.Context) Linker {
|
||||||
|
if v, ok := LinkerFromContext(ctx); !ok {
|
||||||
|
panic("acme linker is not the context")
|
||||||
|
} else {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type baseURLKey struct{}
|
||||||
|
|
||||||
|
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
|
||||||
|
var u *url.URL
|
||||||
|
if r.Host != "" {
|
||||||
|
u = &url.URL{Scheme: "https", Host: r.Host}
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, baseURLKey{}, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||||
|
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// linker generates ACME links.
|
||||||
|
type linker struct {
|
||||||
|
prefix string
|
||||||
|
dns string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware gets the provisioner and current url from the request and sets
|
||||||
|
// them in the context so we can use the linker to create ACME links.
|
||||||
|
func (l *linker) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Add base url to the context.
|
||||||
|
ctx := newBaseURLContext(r.Context(), r)
|
||||||
|
|
||||||
|
// Add provisioner to the context.
|
||||||
|
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||||
|
name, err := url.PathUnescape(nameEscaped)
|
||||||
|
if err != nil {
|
||||||
|
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
|
||||||
|
if err != nil {
|
||||||
|
render.Error(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acmeProv, ok := p.(*provisioner.ACME)
|
||||||
|
if !ok {
|
||||||
|
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLink is a helper for GetLinkExplicit.
|
||||||
|
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||||
|
var name string
|
||||||
|
if p, ok := ProvisionerFromContext(ctx); ok {
|
||||||
|
name = p.GetName()
|
||||||
|
}
|
||||||
|
|
||||||
|
var u url.URL
|
||||||
|
if baseURL := baseURLFromContext(ctx); baseURL != nil {
|
||||||
|
u = *baseURL
|
||||||
|
}
|
||||||
|
if u.Scheme == "" {
|
||||||
|
u.Scheme = "https"
|
||||||
|
}
|
||||||
|
if u.Host == "" {
|
||||||
|
u.Host = l.dns
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
// LinkOrder sets the ACME links required by an ACME order.
|
// LinkOrder sets the ACME links required by an ACME order.
|
||||||
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
func (l *linker) LinkOrder(ctx context.Context, o *Order) {
|
||||||
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
||||||
for i, azID := range o.AuthorizationIDs {
|
for i, azID := range o.AuthorizationIDs {
|
||||||
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
|
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
|
||||||
|
@ -173,17 +239,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkAccount sets the ACME links required by an ACME account.
|
// LinkAccount sets the ACME links required by an ACME account.
|
||||||
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
|
func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
|
||||||
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
|
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkChallenge sets the ACME links required by an ACME challenge.
|
// LinkChallenge sets the ACME links required by an ACME challenge.
|
||||||
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
|
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
|
||||||
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
|
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
||||||
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
|
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
|
||||||
for _, ch := range az.Challenges {
|
for _, ch := range az.Challenges {
|
||||||
l.LinkChallenge(ctx, ch, az.ID)
|
l.LinkChallenge(ctx, ch, az.ID)
|
||||||
}
|
}
|
|
@ -1,21 +1,38 @@
|
||||||
package api
|
package acme
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
func mockProvisioner(t *testing.T) Provisioner {
|
||||||
dns := "ca.smallstep.com"
|
t.Helper()
|
||||||
prefix := "acme"
|
var defaultDisableRenewal = false
|
||||||
linker := NewLinker(dns, prefix)
|
|
||||||
|
|
||||||
getPath := linker.GetUnescapedPathSuffix
|
// Initialize provisioners
|
||||||
|
p := &provisioner.ACME{
|
||||||
|
Type: "ACME",
|
||||||
|
Name: "test@acme-<test>provisioner.com",
|
||||||
|
}
|
||||||
|
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
|
||||||
|
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||||
|
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
|
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
}}); err != nil {
|
||||||
|
fmt.Printf("%v", err)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUnescapedPathSuffix(t *testing.T) {
|
||||||
|
getPath := GetUnescapedPathSuffix
|
||||||
|
|
||||||
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
|
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
|
||||||
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
|
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
|
||||||
|
@ -32,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLinker_DNS(t *testing.T) {
|
func TestLinker_DNS(t *testing.T) {
|
||||||
prov := newProv()
|
prov := mockProvisioner(t)
|
||||||
escProvName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := NewProvisionerContext(context.Background(), prov)
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
dns string
|
dns string
|
||||||
|
@ -117,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) {
|
||||||
linker := NewLinker(dns, prefix)
|
linker := NewLinker(dns, prefix)
|
||||||
id := "1234"
|
id := "1234"
|
||||||
|
|
||||||
prov := newProv()
|
prov := mockProvisioner(t)
|
||||||
escProvName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
ctx := NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||||
|
|
||||||
// No provisioner and no BaseURL from request
|
// No provisioner and no BaseURL from request
|
||||||
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
|
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
|
||||||
// Provisioner: yes, BaseURL: no
|
// Provisioner: yes, BaseURL: no
|
||||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||||
|
|
||||||
// Provisioner: no, BaseURL: yes
|
// Provisioner: no, BaseURL: yes
|
||||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||||
|
|
||||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||||
|
@ -163,37 +180,37 @@ func TestLinker_GetLink(t *testing.T) {
|
||||||
|
|
||||||
func TestLinker_LinkOrder(t *testing.T) {
|
func TestLinker_LinkOrder(t *testing.T) {
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
prov := newProv()
|
prov := mockProvisioner(t)
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
ctx := NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||||
|
|
||||||
oid := "orderID"
|
oid := "orderID"
|
||||||
certID := "certID"
|
certID := "certID"
|
||||||
linkerPrefix := "acme"
|
linkerPrefix := "acme"
|
||||||
l := NewLinker("dns", linkerPrefix)
|
l := NewLinker("dns", linkerPrefix)
|
||||||
type test struct {
|
type test struct {
|
||||||
o *acme.Order
|
o *Order
|
||||||
validate func(o *acme.Order)
|
validate func(o *Order)
|
||||||
}
|
}
|
||||||
var tests = map[string]test{
|
var tests = map[string]test{
|
||||||
"no-authz-and-no-cert": {
|
"no-authz-and-no-cert": {
|
||||||
o: &acme.Order{
|
o: &Order{
|
||||||
ID: oid,
|
ID: oid,
|
||||||
},
|
},
|
||||||
validate: func(o *acme.Order) {
|
validate: func(o *Order) {
|
||||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||||
assert.Equals(t, o.AuthorizationURLs, []string{})
|
assert.Equals(t, o.AuthorizationURLs, []string{})
|
||||||
assert.Equals(t, o.CertificateURL, "")
|
assert.Equals(t, o.CertificateURL, "")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"one-authz-and-cert": {
|
"one-authz-and-cert": {
|
||||||
o: &acme.Order{
|
o: &Order{
|
||||||
ID: oid,
|
ID: oid,
|
||||||
CertificateID: certID,
|
CertificateID: certID,
|
||||||
AuthorizationIDs: []string{"foo"},
|
AuthorizationIDs: []string{"foo"},
|
||||||
},
|
},
|
||||||
validate: func(o *acme.Order) {
|
validate: func(o *Order) {
|
||||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||||
|
@ -202,12 +219,12 @@ func TestLinker_LinkOrder(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"many-authz": {
|
"many-authz": {
|
||||||
o: &acme.Order{
|
o: &Order{
|
||||||
ID: oid,
|
ID: oid,
|
||||||
CertificateID: certID,
|
CertificateID: certID,
|
||||||
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
||||||
},
|
},
|
||||||
validate: func(o *acme.Order) {
|
validate: func(o *Order) {
|
||||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||||
|
@ -228,24 +245,24 @@ func TestLinker_LinkOrder(t *testing.T) {
|
||||||
|
|
||||||
func TestLinker_LinkAccount(t *testing.T) {
|
func TestLinker_LinkAccount(t *testing.T) {
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
prov := newProv()
|
prov := mockProvisioner(t)
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
ctx := NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||||
|
|
||||||
accID := "accountID"
|
accID := "accountID"
|
||||||
linkerPrefix := "acme"
|
linkerPrefix := "acme"
|
||||||
l := NewLinker("dns", linkerPrefix)
|
l := NewLinker("dns", linkerPrefix)
|
||||||
type test struct {
|
type test struct {
|
||||||
a *acme.Account
|
a *Account
|
||||||
validate func(o *acme.Account)
|
validate func(o *Account)
|
||||||
}
|
}
|
||||||
var tests = map[string]test{
|
var tests = map[string]test{
|
||||||
"ok": {
|
"ok": {
|
||||||
a: &acme.Account{
|
a: &Account{
|
||||||
ID: accID,
|
ID: accID,
|
||||||
},
|
},
|
||||||
validate: func(a *acme.Account) {
|
validate: func(a *Account) {
|
||||||
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -260,25 +277,25 @@ func TestLinker_LinkAccount(t *testing.T) {
|
||||||
|
|
||||||
func TestLinker_LinkChallenge(t *testing.T) {
|
func TestLinker_LinkChallenge(t *testing.T) {
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
prov := newProv()
|
prov := mockProvisioner(t)
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
ctx := NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||||
|
|
||||||
chID := "chID"
|
chID := "chID"
|
||||||
azID := "azID"
|
azID := "azID"
|
||||||
linkerPrefix := "acme"
|
linkerPrefix := "acme"
|
||||||
l := NewLinker("dns", linkerPrefix)
|
l := NewLinker("dns", linkerPrefix)
|
||||||
type test struct {
|
type test struct {
|
||||||
ch *acme.Challenge
|
ch *Challenge
|
||||||
validate func(o *acme.Challenge)
|
validate func(o *Challenge)
|
||||||
}
|
}
|
||||||
var tests = map[string]test{
|
var tests = map[string]test{
|
||||||
"ok": {
|
"ok": {
|
||||||
ch: &acme.Challenge{
|
ch: &Challenge{
|
||||||
ID: chID,
|
ID: chID,
|
||||||
},
|
},
|
||||||
validate: func(ch *acme.Challenge) {
|
validate: func(ch *Challenge) {
|
||||||
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -293,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
|
||||||
|
|
||||||
func TestLinker_LinkAuthorization(t *testing.T) {
|
func TestLinker_LinkAuthorization(t *testing.T) {
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
prov := newProv()
|
prov := mockProvisioner(t)
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
ctx := NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||||
|
|
||||||
chID0 := "chID-0"
|
chID0 := "chID-0"
|
||||||
chID1 := "chID-1"
|
chID1 := "chID-1"
|
||||||
|
@ -305,20 +322,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
||||||
linkerPrefix := "acme"
|
linkerPrefix := "acme"
|
||||||
l := NewLinker("dns", linkerPrefix)
|
l := NewLinker("dns", linkerPrefix)
|
||||||
type test struct {
|
type test struct {
|
||||||
az *acme.Authorization
|
az *Authorization
|
||||||
validate func(o *acme.Authorization)
|
validate func(o *Authorization)
|
||||||
}
|
}
|
||||||
var tests = map[string]test{
|
var tests = map[string]test{
|
||||||
"ok": {
|
"ok": {
|
||||||
az: &acme.Authorization{
|
az: &Authorization{
|
||||||
ID: azID,
|
ID: azID,
|
||||||
Challenges: []*acme.Challenge{
|
Challenges: []*Challenge{
|
||||||
{ID: chID0},
|
{ID: chID0},
|
||||||
{ID: chID1},
|
{ID: chID1},
|
||||||
{ID: chID2},
|
{ID: chID2},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(az *acme.Authorization) {
|
validate: func(az *Authorization) {
|
||||||
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
||||||
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
||||||
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
||||||
|
@ -335,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
||||||
|
|
||||||
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
prov := newProv()
|
prov := mockProvisioner(t)
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
ctx := NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||||
|
|
||||||
linkerPrefix := "acme"
|
linkerPrefix := "acme"
|
||||||
l := NewLinker("dns", linkerPrefix)
|
l := NewLinker("dns", linkerPrefix)
|
107
api/api.go
107
api/api.go
|
@ -35,7 +35,6 @@ type Authority interface {
|
||||||
SSHAuthority
|
SSHAuthority
|
||||||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
|
||||||
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
|
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||||
GetTLSOptions() *config.TLSOptions
|
GetTLSOptions() *config.TLSOptions
|
||||||
Root(shasum string) (*x509.Certificate, error)
|
Root(shasum string) (*x509.Certificate, error)
|
||||||
|
@ -52,6 +51,11 @@ type Authority interface {
|
||||||
Version() authority.Version
|
Version() authority.Version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mustAuthority will be replaced on unit tests.
|
||||||
|
var mustAuthority = func(ctx context.Context) Authority {
|
||||||
|
return authority.MustFromContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// TimeDuration is an alias of provisioner.TimeDuration
|
// TimeDuration is an alias of provisioner.TimeDuration
|
||||||
type TimeDuration = provisioner.TimeDuration
|
type TimeDuration = provisioner.TimeDuration
|
||||||
|
|
||||||
|
@ -243,48 +247,53 @@ type caHandler struct {
|
||||||
Authority Authority
|
Authority Authority
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new RouterHandler with the CA endpoints.
|
// Route configures the http request router.
|
||||||
func New(auth Authority) RouterHandler {
|
func (h *caHandler) Route(r Router) {
|
||||||
return &caHandler{
|
Route(r)
|
||||||
Authority: auth,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *caHandler) Route(r Router) {
|
// New creates a new RouterHandler with the CA endpoints.
|
||||||
r.MethodFunc("GET", "/version", h.Version)
|
//
|
||||||
r.MethodFunc("GET", "/health", h.Health)
|
// Deprecated: Use api.Route(r Router)
|
||||||
r.MethodFunc("GET", "/root/{sha}", h.Root)
|
func New(auth Authority) RouterHandler {
|
||||||
r.MethodFunc("POST", "/sign", h.Sign)
|
return &caHandler{}
|
||||||
r.MethodFunc("POST", "/renew", h.Renew)
|
}
|
||||||
r.MethodFunc("POST", "/rekey", h.Rekey)
|
|
||||||
r.MethodFunc("POST", "/revoke", h.Revoke)
|
func Route(r Router) {
|
||||||
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
r.MethodFunc("GET", "/version", Version)
|
||||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
r.MethodFunc("GET", "/health", Health)
|
||||||
r.MethodFunc("GET", "/roots", h.Roots)
|
r.MethodFunc("GET", "/root/{sha}", Root)
|
||||||
r.MethodFunc("GET", "/roots.pem", h.RootsPEM)
|
r.MethodFunc("POST", "/sign", Sign)
|
||||||
r.MethodFunc("GET", "/federation", h.Federation)
|
r.MethodFunc("POST", "/renew", Renew)
|
||||||
|
r.MethodFunc("POST", "/rekey", Rekey)
|
||||||
|
r.MethodFunc("POST", "/revoke", Revoke)
|
||||||
|
r.MethodFunc("GET", "/provisioners", Provisioners)
|
||||||
|
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
|
||||||
|
r.MethodFunc("GET", "/roots", Roots)
|
||||||
|
r.MethodFunc("GET", "/roots.pem", RootsPEM)
|
||||||
|
r.MethodFunc("GET", "/federation", Federation)
|
||||||
// SSH CA
|
// SSH CA
|
||||||
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
|
r.MethodFunc("POST", "/ssh/sign", SSHSign)
|
||||||
r.MethodFunc("POST", "/ssh/renew", h.SSHRenew)
|
r.MethodFunc("POST", "/ssh/renew", SSHRenew)
|
||||||
r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke)
|
r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
|
||||||
r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey)
|
r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
|
||||||
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots)
|
r.MethodFunc("GET", "/ssh/roots", SSHRoots)
|
||||||
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation)
|
r.MethodFunc("GET", "/ssh/federation", SSHFederation)
|
||||||
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
|
r.MethodFunc("POST", "/ssh/config", SSHConfig)
|
||||||
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
|
r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
|
||||||
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
|
r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
|
||||||
r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts)
|
r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
|
||||||
r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion)
|
r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
|
||||||
|
|
||||||
// For compatibility with old code:
|
// For compatibility with old code:
|
||||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
r.MethodFunc("POST", "/re-sign", Renew)
|
||||||
r.MethodFunc("POST", "/sign-ssh", h.SSHSign)
|
r.MethodFunc("POST", "/sign-ssh", SSHSign)
|
||||||
r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts)
|
r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Version is an HTTP handler that returns the version of the server.
|
// Version is an HTTP handler that returns the version of the server.
|
||||||
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
func Version(w http.ResponseWriter, r *http.Request) {
|
||||||
v := h.Authority.Version()
|
v := mustAuthority(r.Context()).Version()
|
||||||
render.JSON(w, VersionResponse{
|
render.JSON(w, VersionResponse{
|
||||||
Version: v.Version,
|
Version: v.Version,
|
||||||
RequireClientAuthentication: v.RequireClientAuthentication,
|
RequireClientAuthentication: v.RequireClientAuthentication,
|
||||||
|
@ -292,17 +301,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health is an HTTP handler that returns the status of the server.
|
// Health is an HTTP handler that returns the status of the server.
|
||||||
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
func Health(w http.ResponseWriter, r *http.Request) {
|
||||||
render.JSON(w, HealthResponse{Status: "ok"})
|
render.JSON(w, HealthResponse{Status: "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
||||||
// certificate for the given SHA256.
|
// certificate for the given SHA256.
|
||||||
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
func Root(w http.ResponseWriter, r *http.Request) {
|
||||||
sha := chi.URLParam(r, "sha")
|
sha := chi.URLParam(r, "sha")
|
||||||
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
||||||
// Load root certificate with the
|
// Load root certificate with the
|
||||||
cert, err := h.Authority.Root(sum)
|
cert, err := mustAuthority(r.Context()).Root(sum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||||
return
|
return
|
||||||
|
@ -320,18 +329,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Provisioners returns the list of provisioners configured in the authority.
|
// Provisioners returns the list of provisioners configured in the authority.
|
||||||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
func Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||||
cursor, limit, err := ParseCursor(r)
|
cursor, limit, err := ParseCursor(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
render.JSON(w, &ProvisionersResponse{
|
render.JSON(w, &ProvisionersResponse{
|
||||||
Provisioners: p,
|
Provisioners: p,
|
||||||
NextCursor: next,
|
NextCursor: next,
|
||||||
|
@ -339,19 +349,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvisionerKey returns the encrypted key of a provisioner by it's key id.
|
// ProvisionerKey returns the encrypted key of a provisioner by it's key id.
|
||||||
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||||
kid := chi.URLParam(r, "kid")
|
kid := chi.URLParam(r, "kid")
|
||||||
key, err := h.Authority.GetEncryptedKey(kid)
|
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.NotFoundErr(err))
|
render.Error(w, errs.NotFoundErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
render.JSON(w, &ProvisionerKeyResponse{key})
|
render.JSON(w, &ProvisionerKeyResponse{key})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Roots returns all the root certificates for the CA.
|
// Roots returns all the root certificates for the CA.
|
||||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
func Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
roots, err := h.Authority.GetRoots()
|
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||||
return
|
return
|
||||||
|
@ -368,8 +379,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RootsPEM returns all the root certificates for the CA in PEM format.
|
// RootsPEM returns all the root certificates for the CA in PEM format.
|
||||||
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
|
func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||||
roots, err := h.Authority.GetRoots()
|
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -391,8 +402,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Federation returns all the public certificates in the federation.
|
// Federation returns all the public certificates in the federation.
|
||||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
func Federation(w http.ResponseWriter, r *http.Request) {
|
||||||
federated, err := h.Authority.GetFederation()
|
federated, err := mustAuthority(r.Context()).GetFederation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||||
return
|
return
|
||||||
|
|
|
@ -171,10 +171,21 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
|
||||||
return csr
|
return csr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mockMustAuthority(t *testing.T, a Authority) {
|
||||||
|
t.Helper()
|
||||||
|
fn := mustAuthority
|
||||||
|
t.Cleanup(func() {
|
||||||
|
mustAuthority = fn
|
||||||
|
})
|
||||||
|
mustAuthority = func(ctx context.Context) Authority {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type mockAuthority struct {
|
type mockAuthority struct {
|
||||||
ret1, ret2 interface{}
|
ret1, ret2 interface{}
|
||||||
err error
|
err error
|
||||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||||
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
|
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||||
getTLSOptions func() *authority.TLSOptions
|
getTLSOptions func() *authority.TLSOptions
|
||||||
root func(shasum string) (*x509.Certificate, error)
|
root func(shasum string) (*x509.Certificate, error)
|
||||||
|
@ -203,12 +214,8 @@ type mockAuthority struct {
|
||||||
|
|
||||||
// TODO: remove once Authorize is deprecated.
|
// TODO: remove once Authorize is deprecated.
|
||||||
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
return m.AuthorizeSign(ott)
|
if m.authorize != nil {
|
||||||
}
|
return m.authorize(ctx, ott)
|
||||||
|
|
||||||
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
|
||||||
if m.authorizeSign != nil {
|
|
||||||
return m.authorizeSign(ott)
|
|
||||||
}
|
}
|
||||||
return m.ret1.([]provisioner.SignOption), m.err
|
return m.ret1.([]provisioner.SignOption), m.err
|
||||||
}
|
}
|
||||||
|
@ -789,11 +796,10 @@ func Test_caHandler_Route(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Health(t *testing.T) {
|
func Test_Health(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h := New(&mockAuthority{}).(*caHandler)
|
Health(w, req)
|
||||||
h.Health(w, req)
|
|
||||||
|
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
if res.StatusCode != 200 {
|
if res.StatusCode != 200 {
|
||||||
|
@ -811,7 +817,7 @@ func Test_caHandler_Health(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Root(t *testing.T) {
|
func Test_Root(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
root *x509.Certificate
|
root *x509.Certificate
|
||||||
|
@ -832,9 +838,9 @@ func Test_caHandler_Root(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
|
mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Root(w, req)
|
Root(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -855,7 +861,7 @@ func Test_caHandler_Root(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Sign(t *testing.T) {
|
func Test_Sign(t *testing.T) {
|
||||||
csr := parseCertificateRequest(csrPEM)
|
csr := parseCertificateRequest(csrPEM)
|
||||||
valid, err := json.Marshal(SignRequest{
|
valid, err := json.Marshal(SignRequest{
|
||||||
CsrPEM: CertificateRequest{csr},
|
CsrPEM: CertificateRequest{csr},
|
||||||
|
@ -896,18 +902,18 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
return tt.certAttrOpts, tt.autherr
|
return tt.certAttrOpts, tt.autherr
|
||||||
},
|
},
|
||||||
getTLSOptions: func() *authority.TLSOptions {
|
getTLSOptions: func() *authority.TLSOptions {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
|
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Sign(logging.NewResponseLogger(w), req)
|
Sign(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -928,7 +934,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Renew(t *testing.T) {
|
func Test_Renew(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
@ -1018,7 +1024,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||||
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||||
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
||||||
|
@ -1039,12 +1045,12 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
getTLSOptions: func() *authority.TLSOptions {
|
getTLSOptions: func() *authority.TLSOptions {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
req.Header = tt.header
|
req.Header = tt.header
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Renew(logging.NewResponseLogger(w), req)
|
Renew(logging.NewResponseLogger(w), req)
|
||||||
|
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
@ -1073,7 +1079,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Rekey(t *testing.T) {
|
func Test_Rekey(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
@ -1104,16 +1110,16 @@ func Test_caHandler_Rekey(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||||
getTLSOptions: func() *authority.TLSOptions {
|
getTLSOptions: func() *authority.TLSOptions {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
|
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Rekey(logging.NewResponseLogger(w), req)
|
Rekey(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -1134,7 +1140,7 @@ func Test_caHandler_Rekey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Provisioners(t *testing.T) {
|
func Test_Provisioners(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Authority Authority
|
Authority Authority
|
||||||
}
|
}
|
||||||
|
@ -1200,10 +1206,8 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := &caHandler{
|
mockMustAuthority(t, tt.fields.Authority)
|
||||||
Authority: tt.fields.Authority,
|
Provisioners(tt.args.w, tt.args.r)
|
||||||
}
|
|
||||||
h.Provisioners(tt.args.w, tt.args.r)
|
|
||||||
|
|
||||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||||
res := rec.Result()
|
res := rec.Result()
|
||||||
|
@ -1238,7 +1242,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_ProvisionerKey(t *testing.T) {
|
func Test_ProvisionerKey(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Authority Authority
|
Authority Authority
|
||||||
}
|
}
|
||||||
|
@ -1270,10 +1274,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := &caHandler{
|
mockMustAuthority(t, tt.fields.Authority)
|
||||||
Authority: tt.fields.Authority,
|
ProvisionerKey(tt.args.w, tt.args.r)
|
||||||
}
|
|
||||||
h.ProvisionerKey(tt.args.w, tt.args.r)
|
|
||||||
|
|
||||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||||
res := rec.Result()
|
res := rec.Result()
|
||||||
|
@ -1298,7 +1300,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Roots(t *testing.T) {
|
func Test_Roots(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
@ -1319,11 +1321,11 @@ func Test_caHandler_Roots(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||||
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Roots(w, req)
|
Roots(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -1360,10 +1362,10 @@ func Test_caHandler_RootsPEM(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler)
|
mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
|
||||||
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
|
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.RootsPEM(w, req)
|
RootsPEM(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -1384,7 +1386,7 @@ func Test_caHandler_RootsPEM(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Federation(t *testing.T) {
|
func Test_Federation(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
@ -1405,11 +1407,11 @@ func Test_caHandler_Federation(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||||
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Federation(w, req)
|
Federation(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
|
|
@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
||||||
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||||
render.Error(w, errs.BadRequest("missing client certificate"))
|
render.Error(w, errs.BadRequest("missing client certificate"))
|
||||||
return
|
return
|
||||||
|
@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
a := mustAuthority(r.Context())
|
||||||
|
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||||
return
|
return
|
||||||
|
@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
||||||
ServerPEM: certChainPEM[0],
|
ServerPEM: certChainPEM[0],
|
||||||
CaPEM: caPEM,
|
CaPEM: caPEM,
|
||||||
CertChainPEM: certChainPEM,
|
CertChainPEM: certChainPEM,
|
||||||
TLSOptions: h.Authority.GetTLSOptions(),
|
TLSOptions: a.GetTLSOptions(),
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
14
api/renew.go
14
api/renew.go
|
@ -16,14 +16,15 @@ const (
|
||||||
|
|
||||||
// Renew uses the information of certificate in the TLS connection to create a
|
// Renew uses the information of certificate in the TLS connection to create a
|
||||||
// new one.
|
// new one.
|
||||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
func Renew(w http.ResponseWriter, r *http.Request) {
|
||||||
cert, err := h.getPeerCertificate(r)
|
cert, err := getPeerCertificate(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certChain, err := h.Authority.Renew(cert)
|
a := mustAuthority(r.Context())
|
||||||
|
certChain, err := a.Renew(cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||||
return
|
return
|
||||||
|
@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||||
ServerPEM: certChainPEM[0],
|
ServerPEM: certChainPEM[0],
|
||||||
CaPEM: caPEM,
|
CaPEM: caPEM,
|
||||||
CertChainPEM: certChainPEM,
|
CertChainPEM: certChainPEM,
|
||||||
TLSOptions: h.Authority.GetTLSOptions(),
|
TLSOptions: a.GetTLSOptions(),
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||||
return r.TLS.PeerCertificates[0], nil
|
return r.TLS.PeerCertificates[0], nil
|
||||||
}
|
}
|
||||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||||
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
|
ctx := r.Context()
|
||||||
|
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errs.BadRequest("missing client certificate")
|
return nil, errs.BadRequest("missing client certificate")
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"golang.org/x/crypto/ocsp"
|
"golang.org/x/crypto/ocsp"
|
||||||
|
@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) {
|
||||||
// NOTE: currently only Passive revocation is supported.
|
// NOTE: currently only Passive revocation is supported.
|
||||||
//
|
//
|
||||||
// TODO: Add CRL and OCSP support.
|
// TODO: Add CRL and OCSP support.
|
||||||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
var body RevokeRequest
|
var body RevokeRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
PassiveOnly: body.Passive,
|
PassiveOnly: body.Passive,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
|
||||||
|
a := mustAuthority(ctx)
|
||||||
|
|
||||||
// A token indicates that we are using the api via a provisioner token,
|
// A token indicates that we are using the api via a provisioner token,
|
||||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||||
if len(body.OTT) > 0 {
|
if len(body.OTT) > 0 {
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||||
render.Error(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
opts.MTLS = true
|
opts.MTLS = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
if err := a.Revoke(ctx, opts); err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
||||||
input: string(input),
|
input: string(input),
|
||||||
statusCode: http.StatusOK,
|
statusCode: http.StatusOK,
|
||||||
auth: &mockAuthority{
|
auth: &mockAuthority{
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||||
|
@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
||||||
statusCode: http.StatusOK,
|
statusCode: http.StatusOK,
|
||||||
tls: cs,
|
tls: cs,
|
||||||
auth: &mockAuthority{
|
auth: &mockAuthority{
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
|
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
|
||||||
|
@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
||||||
input: string(input),
|
input: string(input),
|
||||||
statusCode: http.StatusInternalServerError,
|
statusCode: http.StatusInternalServerError,
|
||||||
auth: &mockAuthority{
|
auth: &mockAuthority{
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||||
|
@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
||||||
input: string(input),
|
input: string(input),
|
||||||
statusCode: http.StatusForbidden,
|
statusCode: http.StatusForbidden,
|
||||||
auth: &mockAuthority{
|
auth: &mockAuthority{
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||||
|
@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
|
||||||
for name, _tc := range tests {
|
for name, _tc := range tests {
|
||||||
tc := _tc(t)
|
tc := _tc(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(tc.auth).(*caHandler)
|
mockMustAuthority(t, tc.auth)
|
||||||
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
|
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
|
||||||
if tc.tls != nil {
|
if tc.tls != nil {
|
||||||
req.TLS = tc.tls
|
req.TLS = tc.tls
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Revoke(logging.NewResponseLogger(w), req)
|
Revoke(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
12
api/sign.go
12
api/sign.go
|
@ -49,7 +49,7 @@ type SignResponse struct {
|
||||||
// Sign is an HTTP handler that reads a certificate request and an
|
// Sign is an HTTP handler that reads a certificate request and an
|
||||||
// one-time-token (ott) from the body and creates a new certificate with the
|
// one-time-token (ott) from the body and creates a new certificate with the
|
||||||
// information in the certificate request.
|
// information in the certificate request.
|
||||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
func Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SignRequest
|
var body SignRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -68,13 +68,17 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
TemplateData: body.TemplateData,
|
TemplateData: body.TemplateData,
|
||||||
}
|
}
|
||||||
|
|
||||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
ctx := r.Context()
|
||||||
|
a := mustAuthority(ctx)
|
||||||
|
|
||||||
|
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||||
|
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||||
return
|
return
|
||||||
|
@ -89,6 +93,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
ServerPEM: certChainPEM[0],
|
ServerPEM: certChainPEM[0],
|
||||||
CaPEM: caPEM,
|
CaPEM: caPEM,
|
||||||
CertChainPEM: certChainPEM,
|
CertChainPEM: certChainPEM,
|
||||||
TLSOptions: h.Authority.GetTLSOptions(),
|
TLSOptions: a.GetTLSOptions(),
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
44
api/ssh.go
44
api/ssh.go
|
@ -250,7 +250,7 @@ type SSHBastionResponse struct {
|
||||||
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||||
// (ott) from the body and creates a new SSH certificate with the information in
|
// (ott) from the body and creates a new SSH certificate with the information in
|
||||||
// the request.
|
// the request.
|
||||||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHSignRequest
|
var body SSHSignRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -288,13 +288,15 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
|
||||||
|
a := mustAuthority(ctx)
|
||||||
|
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||||
return
|
return
|
||||||
|
@ -302,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
var addUserCertificate *SSHCertificate
|
var addUserCertificate *SSHCertificate
|
||||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||||
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||||
return
|
return
|
||||||
|
@ -315,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
if cr := body.IdentityCSR.CertificateRequest; cr != nil {
|
if cr := body.IdentityCSR.CertificateRequest; cr != nil {
|
||||||
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
|
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
|
||||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
|
@ -327,7 +329,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
|
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
|
||||||
})
|
})
|
||||||
|
|
||||||
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||||
return
|
return
|
||||||
|
@ -344,8 +346,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
|
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
|
||||||
// certificates.
|
// certificates.
|
||||||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||||
keys, err := h.Authority.GetSSHRoots(r.Context())
|
ctx := r.Context()
|
||||||
|
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -369,8 +372,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||||
// for user and host certificates.
|
// for user and host certificates.
|
||||||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||||
keys, err := h.Authority.GetSSHFederation(r.Context())
|
ctx := r.Context()
|
||||||
|
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -394,7 +398,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
||||||
// and servers.
|
// and servers.
|
||||||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHConfigRequest
|
var body SSHConfigRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -405,7 +409,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
|
ctx := r.Context()
|
||||||
|
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -426,7 +431,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||||
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHCheckPrincipalRequest
|
var body SSHCheckPrincipalRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -437,7 +442,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
|
ctx := r.Context()
|
||||||
|
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -448,13 +454,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||||
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||||
var cert *x509.Certificate
|
var cert *x509.Certificate
|
||||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||||
cert = r.TLS.PeerCertificates[0]
|
cert = r.TLS.PeerCertificates[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
|
ctx := r.Context()
|
||||||
|
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -465,7 +472,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHBastion provides returns the bastion configured if any.
|
// SSHBastion provides returns the bastion configured if any.
|
||||||
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
func SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHBastionRequest
|
var body SSHBastionRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -476,7 +483,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
|
ctx := r.Context()
|
||||||
|
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -39,7 +39,7 @@ type SSHRekeyResponse struct {
|
||||||
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
|
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
|
||||||
// (ott) from the body and creates a new SSH certificate with the information in
|
// (ott) from the body and creates a new SSH certificate with the information in
|
||||||
// the request.
|
// the request.
|
||||||
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHRekeyRequest
|
var body SSHRekeyRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -59,7 +59,9 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
|
||||||
|
a := mustAuthority(ctx)
|
||||||
|
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
|
@ -70,7 +72,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||||
return
|
return
|
||||||
|
@ -80,7 +82,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||||
|
|
||||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||||
return
|
return
|
||||||
|
|
|
@ -37,7 +37,7 @@ type SSHRenewResponse struct {
|
||||||
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
|
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
|
||||||
// (ott) from the body and creates a new SSH certificate with the information in
|
// (ott) from the body and creates a new SSH certificate with the information in
|
||||||
// the request.
|
// the request.
|
||||||
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHRenewRequest
|
var body SSHRenewRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -51,7 +51,8 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
||||||
_, err := h.Authority.Authorize(ctx, body.OTT)
|
a := mustAuthority(ctx)
|
||||||
|
_, err := a.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
|
@ -62,7 +63,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
newCert, err := a.RenewSSH(ctx, oldCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||||
return
|
return
|
||||||
|
@ -72,7 +73,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||||
|
|
||||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||||
return
|
return
|
||||||
|
@ -85,7 +86,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
|
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
|
||||||
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -105,7 +106,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte
|
||||||
cert.NotAfter = notAfter
|
cert.NotAfter = notAfter
|
||||||
}
|
}
|
||||||
|
|
||||||
certChain, err := h.Authority.Renew(cert)
|
certChain, err := mustAuthority(r.Context()).Renew(cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
||||||
// Revoke supports handful of different methods that revoke a Certificate.
|
// Revoke supports handful of different methods that revoke a Certificate.
|
||||||
//
|
//
|
||||||
// NOTE: currently only Passive revocation is supported.
|
// NOTE: currently only Passive revocation is supported.
|
||||||
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHRevokeRequest
|
var body SSHRevokeRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
|
@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
|
||||||
|
a := mustAuthority(ctx)
|
||||||
|
|
||||||
// A token indicates that we are using the api via a provisioner token,
|
// A token indicates that we are using the api via a provisioner token,
|
||||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
|
||||||
|
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||||
render.Error(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
opts.OTT = body.OTT
|
opts.OTT = body.OTT
|
||||||
|
|
||||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
if err := a.Revoke(ctx, opts); err != nil {
|
||||||
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHSign(t *testing.T) {
|
func Test_SSHSign(t *testing.T) {
|
||||||
user, err := getSignedUserCertificate()
|
user, err := getSignedUserCertificate()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
host, err := getSignedHostCertificate()
|
host, err := getSignedHostCertificate()
|
||||||
|
@ -315,8 +315,8 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||||
return []provisioner.SignOption{}, tt.authErr
|
return []provisioner.SignOption{}, tt.authErr
|
||||||
},
|
},
|
||||||
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
|
@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||||
return tt.tlsSignCerts, tt.tlsSignErr
|
return tt.tlsSignCerts, tt.tlsSignErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHSign(logging.NewResponseLogger(w), req)
|
SSHSign(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHRoots(t *testing.T) {
|
func Test_SSHRoots(t *testing.T) {
|
||||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||||
return tt.keys, tt.keysErr
|
return tt.keys, tt.keysErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHRoots(logging.NewResponseLogger(w), req)
|
SSHRoots(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHFederation(t *testing.T) {
|
func Test_SSHFederation(t *testing.T) {
|
||||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||||
return tt.keys, tt.keysErr
|
return tt.keys, tt.keysErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHFederation(logging.NewResponseLogger(w), req)
|
SSHFederation(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHConfig(t *testing.T) {
|
func Test_SSHConfig(t *testing.T) {
|
||||||
userOutput := []templates.Output{
|
userOutput := []templates.Output{
|
||||||
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
|
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
|
||||||
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
||||||
|
@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||||
return tt.output, tt.err
|
return tt.output, tt.err
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHConfig(logging.NewResponseLogger(w), req)
|
SSHConfig(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHCheckHost(t *testing.T) {
|
func Test_SSHCheckHost(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
req string
|
req string
|
||||||
|
@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
||||||
return tt.exists, tt.err
|
return tt.exists, tt.err
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHCheckHost(logging.NewResponseLogger(w), req)
|
SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHGetHosts(t *testing.T) {
|
func Test_SSHGetHosts(t *testing.T) {
|
||||||
hosts := []authority.Host{
|
hosts := []authority.Host{
|
||||||
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
|
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
|
||||||
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
|
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
|
||||||
|
@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
||||||
return tt.hosts, tt.err
|
return tt.hosts, tt.err
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHGetHosts(logging.NewResponseLogger(w), req)
|
SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHBastion(t *testing.T) {
|
func Test_SSHBastion(t *testing.T) {
|
||||||
bastion := &authority.Bastion{
|
bastion := &authority.Bastion{
|
||||||
Hostname: "bastion.local",
|
Hostname: "bastion.local",
|
||||||
}
|
}
|
||||||
|
@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||||
return tt.bastion, tt.bastionErr
|
return tt.bastion, tt.bastionErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
|
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHBastion(logging.NewResponseLogger(w), req)
|
SSHBastion(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
|
|
@ -33,7 +33,7 @@ type GetExternalAccountKeysResponse struct {
|
||||||
|
|
||||||
// requireEABEnabled is a middleware that ensures ACME EAB is enabled
|
// requireEABEnabled is a middleware that ensures ACME EAB is enabled
|
||||||
// before serving requests that act on ACME EAB credentials.
|
// before serving requests that act on ACME EAB credentials.
|
||||||
func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
|
@ -53,32 +53,33 @@ func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type acmeAdminResponderInterface interface {
|
// ACMEAdminResponder is responsible for writing ACME admin responses
|
||||||
|
type ACMEAdminResponder interface {
|
||||||
GetExternalAccountKeys(w http.ResponseWriter, r *http.Request)
|
GetExternalAccountKeys(w http.ResponseWriter, r *http.Request)
|
||||||
CreateExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
CreateExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
||||||
DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ACMEAdminResponder is responsible for writing ACME admin responses
|
// acmeAdminResponder implements ACMEAdminResponder.
|
||||||
type ACMEAdminResponder struct{}
|
type acmeAdminResponder struct{}
|
||||||
|
|
||||||
// NewACMEAdminResponder returns a new ACMEAdminResponder
|
// NewACMEAdminResponder returns a new ACMEAdminResponder
|
||||||
func NewACMEAdminResponder() *ACMEAdminResponder {
|
func NewACMEAdminResponder() ACMEAdminResponder {
|
||||||
return &ACMEAdminResponder{}
|
return &acmeAdminResponder{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
|
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
|
||||||
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
|
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
|
||||||
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
|
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
|
||||||
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
||||||
return protojson.Unmarshal(data, m)
|
return protojson.Unmarshal(data, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mockMustAuthority(t *testing.T, a adminAuthority) {
|
||||||
|
t.Helper()
|
||||||
|
fn := mustAuthority
|
||||||
|
t.Cleanup(func() {
|
||||||
|
mustAuthority = fn
|
||||||
|
})
|
||||||
|
mustAuthority = func(ctx context.Context) adminAuthority {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandler_requireEABEnabled(t *testing.T) {
|
func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -117,12 +128,9 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx)
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil)
|
|
||||||
req = req.WithContext(tc.ctx)
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.requireEABEnabled(tc.next)(w, req)
|
requireEABEnabled(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
|
@ -85,10 +85,10 @@ type DeleteResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdmin returns the requested admin, or an error.
|
// GetAdmin returns the requested admin, or an error.
|
||||||
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
func GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
id := chi.URLParam(r, "id")
|
id := chi.URLParam(r, "id")
|
||||||
|
|
||||||
adm, ok := h.auth.LoadAdminByID(id)
|
adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
|
||||||
if !ok {
|
if !ok {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
|
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
|
||||||
"admin %s not found", id))
|
"admin %s not found", id))
|
||||||
|
@ -98,7 +98,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdmins returns a segment of admins associated with the authority.
|
// GetAdmins returns a segment of admins associated with the authority.
|
||||||
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
func GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||||
cursor, limit, err := api.ParseCursor(r)
|
cursor, limit, err := api.ParseCursor(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||||
|
@ -106,7 +106,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
|
admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||||
return
|
return
|
||||||
|
@ -118,7 +118,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAdmin creates a new admin.
|
// CreateAdmin creates a new admin.
|
||||||
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
var body CreateAdminRequest
|
var body CreateAdminRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||||
|
@ -130,7 +130,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
|
auth := mustAuthority(r.Context())
|
||||||
|
p, err := auth.LoadProvisionerByName(body.Provisioner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||||
return
|
return
|
||||||
|
@ -141,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
Type: body.Type,
|
Type: body.Type,
|
||||||
}
|
}
|
||||||
// Store to authority collection.
|
// Store to authority collection.
|
||||||
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
|
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -150,10 +151,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAdmin deletes admin.
|
// DeleteAdmin deletes admin.
|
||||||
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
id := chi.URLParam(r, "id")
|
id := chi.URLParam(r, "id")
|
||||||
|
|
||||||
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
|
if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -162,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAdmin updates an existing admin.
|
// UpdateAdmin updates an existing admin.
|
||||||
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
var body UpdateAdminRequest
|
var body UpdateAdminRequest
|
||||||
if err := read.JSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||||
|
@ -175,8 +176,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
id := chi.URLParam(r, "id")
|
id := chi.URLParam(r, "id")
|
||||||
|
auth := mustAuthority(r.Context())
|
||||||
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||||
return
|
return
|
||||||
|
|
|
@ -352,14 +352,11 @@ func TestHandler_GetAdmin(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetAdmin(w, req)
|
GetAdmin(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -491,13 +488,10 @@ func TestHandler_GetAdmins(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
req := tc.req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetAdmins(w, req)
|
GetAdmins(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -675,13 +669,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.CreateAdmin(w, req)
|
CreateAdmin(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -767,13 +759,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.DeleteAdmin(w, req)
|
DeleteAdmin(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
|
@ -912,13 +902,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.UpdateAdmin(w, req)
|
UpdateAdmin(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
|
@ -1,50 +1,58 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler is the Admin API request handler.
|
// Handler is the Admin API request handler.
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
adminDB admin.DB
|
acmeResponder ACMEAdminResponder
|
||||||
auth adminAuthority
|
policyResponder PolicyAdminResponder
|
||||||
acmeDB acme.DB
|
}
|
||||||
acmeResponder acmeAdminResponderInterface
|
|
||||||
policyResponder policyAdminResponderInterface
|
// Route traffic and implement the Router interface.
|
||||||
|
//
|
||||||
|
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
|
||||||
|
func (h *Handler) Route(r api.Router) {
|
||||||
|
Route(r, h.acmeResponder, h.policyResponder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler returns a new Authority Config Handler.
|
// NewHandler returns a new Authority Config Handler.
|
||||||
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler {
|
//
|
||||||
|
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
|
||||||
|
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
auth: auth,
|
|
||||||
adminDB: adminDB,
|
|
||||||
acmeDB: acmeDB,
|
|
||||||
acmeResponder: acmeResponder,
|
acmeResponder: acmeResponder,
|
||||||
policyResponder: policyResponder,
|
policyResponder: policyResponder,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route traffic and implement the Router interface.
|
var mustAuthority = func(ctx context.Context) adminAuthority {
|
||||||
func (h *Handler) Route(r api.Router) {
|
return authority.MustFromContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route traffic and implement the Router interface.
|
||||||
|
func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) {
|
||||||
authnz := func(next http.HandlerFunc) http.HandlerFunc {
|
authnz := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
|
return extractAuthorizeTokenAdmin(requireAPIEnabled(next))
|
||||||
}
|
}
|
||||||
|
|
||||||
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return h.checkAction(next, true)
|
return checkAction(next, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return h.checkAction(next, false)
|
return checkAction(next, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return authnz(h.loadProvisionerByName(h.requireEABEnabled(next)))
|
return authnz(loadProvisionerByName(requireEABEnabled(next)))
|
||||||
}
|
}
|
||||||
|
|
||||||
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
@ -52,53 +60,58 @@ func (h *Handler) Route(r api.Router) {
|
||||||
}
|
}
|
||||||
|
|
||||||
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return authnz(disabledInStandalone(h.loadProvisionerByName(next)))
|
return authnz(disabledInStandalone(loadProvisionerByName(next)))
|
||||||
}
|
}
|
||||||
|
|
||||||
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.loadExternalAccountKey(next)))))
|
return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next)))))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Provisioners
|
// Provisioners
|
||||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
|
r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner))
|
||||||
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
|
r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners))
|
||||||
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner))
|
r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner))
|
||||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner))
|
r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner))
|
||||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner))
|
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner))
|
||||||
|
|
||||||
// Admins
|
// Admins
|
||||||
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin))
|
r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin))
|
||||||
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins))
|
r.MethodFunc("GET", "/admins", authnz(GetAdmins))
|
||||||
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin))
|
r.MethodFunc("POST", "/admins", authnz(CreateAdmin))
|
||||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin))
|
r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin))
|
||||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
|
r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin))
|
||||||
|
|
||||||
// ACME External Account Binding Keys
|
// ACME responder
|
||||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys))
|
if acmeResponder != nil {
|
||||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys))
|
// ACME External Account Binding Keys
|
||||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.CreateExternalAccountKey))
|
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
|
||||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(h.acmeResponder.DeleteExternalAccountKey))
|
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
|
||||||
|
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey))
|
||||||
|
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey))
|
||||||
|
}
|
||||||
|
|
||||||
// Policy - Authority
|
// Policy responder
|
||||||
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(h.policyResponder.GetAuthorityPolicy))
|
if policyResponder != nil {
|
||||||
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(h.policyResponder.CreateAuthorityPolicy))
|
// Policy - Authority
|
||||||
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(h.policyResponder.UpdateAuthorityPolicy))
|
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy))
|
||||||
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(h.policyResponder.DeleteAuthorityPolicy))
|
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy))
|
||||||
|
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy))
|
||||||
|
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy))
|
||||||
|
|
||||||
// Policy - Provisioner
|
// Policy - Provisioner
|
||||||
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.GetProvisionerPolicy))
|
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy))
|
||||||
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.CreateProvisionerPolicy))
|
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy))
|
||||||
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.UpdateProvisionerPolicy))
|
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy))
|
||||||
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.DeleteProvisionerPolicy))
|
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy))
|
||||||
|
|
||||||
// Policy - ACME Account
|
|
||||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy))
|
|
||||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy))
|
|
||||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy))
|
|
||||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy))
|
|
||||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy))
|
|
||||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy))
|
|
||||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy))
|
|
||||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy))
|
|
||||||
|
|
||||||
|
// Policy - ACME Account
|
||||||
|
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
|
||||||
|
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
|
||||||
|
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
|
||||||
|
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
|
||||||
|
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
|
||||||
|
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
|
||||||
|
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
|
||||||
|
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,11 +17,10 @@ import (
|
||||||
|
|
||||||
// requireAPIEnabled is a middleware that ensures the Administration API
|
// requireAPIEnabled is a middleware that ensures the Administration API
|
||||||
// is enabled before servicing requests.
|
// is enabled before servicing requests.
|
||||||
func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !h.auth.IsAdminAPIEnabled() {
|
if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
|
||||||
"administration API not enabled"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next(w, r)
|
next(w, r)
|
||||||
|
@ -29,7 +28,7 @@ func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
|
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
|
||||||
func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
tok := r.Header.Get("Authorization")
|
tok := r.Header.Get("Authorization")
|
||||||
|
@ -39,36 +38,39 @@ func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.Handler
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
adm, err := h.auth.AuthorizeAdminToken(r, tok)
|
ctx := r.Context()
|
||||||
|
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := linkedca.NewContextWithAdmin(r.Context(), adm)
|
ctx = linkedca.NewContextWithAdmin(ctx, adm)
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadProvisionerByName is a middleware that searches for a provisioner
|
// loadProvisionerByName is a middleware that searches for a provisioner
|
||||||
// by name and stores it in the context.
|
// by name and stores it in the context.
|
||||||
func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
|
func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
name := chi.URLParam(r, "provisionerName")
|
|
||||||
var (
|
var (
|
||||||
p provisioner.Interface
|
p provisioner.Interface
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
auth := mustAuthority(ctx)
|
||||||
|
adminDB := admin.MustFromContext(ctx)
|
||||||
|
name := chi.URLParam(r, "provisionerName")
|
||||||
|
|
||||||
// TODO(hs): distinguish 404 vs. 500
|
// TODO(hs): distinguish 404 vs. 500
|
||||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
prov, err := adminDB.GetProvisioner(ctx, p.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
|
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
|
||||||
return
|
return
|
||||||
|
@ -80,9 +82,8 @@ func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkAction checks if an action is supported in standalone or not
|
// checkAction checks if an action is supported in standalone or not
|
||||||
func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
|
func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// actions allowed in standalone mode are always supported
|
// actions allowed in standalone mode are always supported
|
||||||
if supportedInStandalone {
|
if supportedInStandalone {
|
||||||
next(w, r)
|
next(w, r)
|
||||||
|
@ -91,7 +92,7 @@ func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool)
|
||||||
|
|
||||||
// when an action is not supported in standalone mode and when
|
// when an action is not supported in standalone mode and when
|
||||||
// using a nosql.DB backend, actions are not supported
|
// using a nosql.DB backend, actions are not supported
|
||||||
if _, ok := h.adminDB.(*nosql.DB); ok {
|
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||||
"operation not supported in standalone mode"))
|
"operation not supported in standalone mode"))
|
||||||
return
|
return
|
||||||
|
@ -104,10 +105,11 @@ func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool)
|
||||||
|
|
||||||
// loadExternalAccountKey is a middleware that searches for an ACME
|
// loadExternalAccountKey is a middleware that searches for an ACME
|
||||||
// External Account Key by reference or keyID and stores it in the context.
|
// External Account Key by reference or keyID and stores it in the context.
|
||||||
func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
|
func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
|
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||||
|
|
||||||
reference := chi.URLParam(r, "reference")
|
reference := chi.URLParam(r, "reference")
|
||||||
keyID := chi.URLParam(r, "keyID")
|
keyID := chi.URLParam(r, "keyID")
|
||||||
|
@ -118,9 +120,9 @@ func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
if keyID != "" {
|
if keyID != "" {
|
||||||
eak, err = h.acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
|
eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
|
||||||
} else {
|
} else {
|
||||||
eak, err = h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
|
eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -71,13 +71,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.requireAPIEnabled(tc.next)(w, req)
|
requireAPIEnabled(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -196,13 +194,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
req := tc.req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
|
extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -251,6 +246,7 @@ func TestHandler_loadProvisionerByName(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: err,
|
err: err,
|
||||||
}
|
}
|
||||||
|
@ -326,16 +322,13 @@ func TestHandler_loadProvisionerByName(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
adminDB: tc.adminDB,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.loadProvisionerByName(tc.next)(w, req)
|
loadProvisionerByName(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -405,14 +398,10 @@ func TestHandler_checkAction(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
ctx := admin.NewContext(context.Background(), tc.adminDB)
|
||||||
|
req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx)
|
||||||
adminDB: tc.adminDB,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil)
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.checkAction(tc.next, tc.supportedInStandalone)(w, req)
|
checkAction(tc.next, tc.supportedInStandalone)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -653,14 +642,11 @@ func TestHandler_loadExternalAccountKey(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB)
|
||||||
acmeDB: tc.acmeDB,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil)
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.loadExternalAccountKey(tc.next)(w, req)
|
loadExternalAccountKey(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
@ -14,7 +15,9 @@ import (
|
||||||
"github.com/smallstep/certificates/authority/policy"
|
"github.com/smallstep/certificates/authority/policy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type policyAdminResponderInterface interface {
|
// PolicyAdminResponder is the interface responsible for writing ACME admin
|
||||||
|
// responses.
|
||||||
|
type PolicyAdminResponder interface {
|
||||||
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||||
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||||
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||||
|
@ -29,39 +32,24 @@ type policyAdminResponderInterface interface {
|
||||||
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PolicyAdminResponder is responsible for writing ACME admin responses
|
// policyAdminResponder implements PolicyAdminResponder.
|
||||||
type PolicyAdminResponder struct {
|
type policyAdminResponder struct{}
|
||||||
auth adminAuthority
|
|
||||||
adminDB admin.DB
|
|
||||||
acmeDB acme.DB
|
|
||||||
isLinkedCA bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewACMEAdminResponder returns a new ACMEAdminResponder
|
// NewACMEAdminResponder returns a new PolicyAdminResponder.
|
||||||
func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) *PolicyAdminResponder {
|
func NewPolicyAdminResponder() PolicyAdminResponder {
|
||||||
|
return &policyAdminResponder{}
|
||||||
var isLinkedCA bool
|
|
||||||
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok {
|
|
||||||
isLinkedCA = a.IsLinkedCA()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PolicyAdminResponder{
|
|
||||||
auth: auth,
|
|
||||||
adminDB: adminDB,
|
|
||||||
acmeDB: acmeDB,
|
|
||||||
isLinkedCA: isLinkedCA,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthorityPolicy handles the GET /admin/authority/policy request
|
// GetAuthorityPolicy handles the GET /admin/authority/policy request
|
||||||
func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
authorityPolicy, err := par.auth.GetAuthorityPolicy(r.Context())
|
auth := mustAuthority(ctx)
|
||||||
|
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
|
||||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||||
return
|
return
|
||||||
|
@ -76,15 +64,15 @@ func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
|
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
|
||||||
func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
auth := mustAuthority(ctx)
|
||||||
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
|
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||||
|
|
||||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||||
|
@ -113,7 +101,7 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
|
||||||
adm := linkedca.MustAdminFromContext(ctx)
|
adm := linkedca.MustAdminFromContext(ctx)
|
||||||
|
|
||||||
var createdPolicy *linkedca.Policy
|
var createdPolicy *linkedca.Policy
|
||||||
if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||||
if isBadRequest(err) {
|
if isBadRequest(err) {
|
||||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
|
||||||
return
|
return
|
||||||
|
@ -127,15 +115,15 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
|
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
|
||||||
func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
auth := mustAuthority(ctx)
|
||||||
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
|
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||||
|
|
||||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||||
|
@ -163,7 +151,7 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
|
||||||
adm := linkedca.MustAdminFromContext(ctx)
|
adm := linkedca.MustAdminFromContext(ctx)
|
||||||
|
|
||||||
var updatedPolicy *linkedca.Policy
|
var updatedPolicy *linkedca.Policy
|
||||||
if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||||
if isBadRequest(err) {
|
if isBadRequest(err) {
|
||||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
|
||||||
return
|
return
|
||||||
|
@ -177,15 +165,15 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
|
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
|
||||||
func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
auth := mustAuthority(ctx)
|
||||||
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
|
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||||
|
|
||||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||||
|
@ -197,7 +185,7 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := par.auth.RemoveAuthorityPolicy(ctx); err != nil {
|
if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
|
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -206,15 +194,14 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
|
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
|
||||||
func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prov := linkedca.MustProvisionerFromContext(r.Context())
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
|
|
||||||
provisionerPolicy := prov.GetPolicy()
|
provisionerPolicy := prov.GetPolicy()
|
||||||
if provisionerPolicy == nil {
|
if provisionerPolicy == nil {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||||
|
@ -225,16 +212,14 @@ func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
|
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
|
||||||
func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
|
|
||||||
provisionerPolicy := prov.GetPolicy()
|
provisionerPolicy := prov.GetPolicy()
|
||||||
if provisionerPolicy != nil {
|
if provisionerPolicy != nil {
|
||||||
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
|
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
|
||||||
|
@ -256,8 +241,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
prov.Policy = newPolicy
|
prov.Policy = newPolicy
|
||||||
|
auth := mustAuthority(ctx)
|
||||||
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
|
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||||
if isBadRequest(err) {
|
if isBadRequest(err) {
|
||||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
|
||||||
return
|
return
|
||||||
|
@ -271,16 +256,14 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
|
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
|
||||||
func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
|
|
||||||
provisionerPolicy := prov.GetPolicy()
|
provisionerPolicy := prov.GetPolicy()
|
||||||
if provisionerPolicy == nil {
|
if provisionerPolicy == nil {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||||
|
@ -301,7 +284,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
prov.Policy = newPolicy
|
prov.Policy = newPolicy
|
||||||
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
|
auth := mustAuthority(ctx)
|
||||||
|
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||||
if isBadRequest(err) {
|
if isBadRequest(err) {
|
||||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
|
||||||
return
|
return
|
||||||
|
@ -315,16 +299,14 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
|
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
|
||||||
func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
|
|
||||||
if prov.Policy == nil {
|
if prov.Policy == nil {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||||
return
|
return
|
||||||
|
@ -333,7 +315,8 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter,
|
||||||
// remove the policy
|
// remove the policy
|
||||||
prov.Policy = nil
|
prov.Policy = nil
|
||||||
|
|
||||||
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
|
auth := mustAuthority(ctx)
|
||||||
|
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
|
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -341,16 +324,14 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter,
|
||||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||||
|
|
||||||
eakPolicy := eak.GetPolicy()
|
eakPolicy := eak.GetPolicy()
|
||||||
if eakPolicy == nil {
|
if eakPolicy == nil {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||||
|
@ -360,17 +341,15 @@ func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *
|
||||||
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
|
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||||
|
|
||||||
eakPolicy := eak.GetPolicy()
|
eakPolicy := eak.GetPolicy()
|
||||||
if eakPolicy != nil {
|
if eakPolicy != nil {
|
||||||
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
|
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
|
||||||
|
@ -394,7 +373,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
|
||||||
eak.Policy = newPolicy
|
eak.Policy = newPolicy
|
||||||
|
|
||||||
acmeEAK := linkedEAKToCertificates(eak)
|
acmeEAK := linkedEAKToCertificates(eak)
|
||||||
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||||
|
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
|
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -402,17 +382,15 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
|
||||||
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
|
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||||
|
|
||||||
eakPolicy := eak.GetPolicy()
|
eakPolicy := eak.GetPolicy()
|
||||||
if eakPolicy == nil {
|
if eakPolicy == nil {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||||
|
@ -434,7 +412,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
|
||||||
|
|
||||||
eak.Policy = newPolicy
|
eak.Policy = newPolicy
|
||||||
acmeEAK := linkedEAKToCertificates(eak)
|
acmeEAK := linkedEAKToCertificates(eak)
|
||||||
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||||
|
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
|
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -442,17 +421,15 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
|
||||||
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
|
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
if err := par.blockLinkedCA(); err != nil {
|
if err := blockLinkedCA(ctx); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||||
|
|
||||||
eakPolicy := eak.GetPolicy()
|
eakPolicy := eak.GetPolicy()
|
||||||
if eakPolicy == nil {
|
if eakPolicy == nil {
|
||||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||||
|
@ -463,7 +440,8 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
|
||||||
eak.Policy = nil
|
eak.Policy = nil
|
||||||
|
|
||||||
acmeEAK := linkedEAKToCertificates(eak)
|
acmeEAK := linkedEAKToCertificates(eak)
|
||||||
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||||
|
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
|
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -472,9 +450,10 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
// blockLinkedCA blocks all API operations on linked deployments
|
// blockLinkedCA blocks all API operations on linked deployments
|
||||||
func (par *PolicyAdminResponder) blockLinkedCA() error {
|
func blockLinkedCA(ctx context.Context) error {
|
||||||
// temporary blocking linked deployments
|
// temporary blocking linked deployments
|
||||||
if par.isLinkedCA {
|
adminDB := admin.MustFromContext(ctx)
|
||||||
|
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() {
|
||||||
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
|
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -109,7 +109,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
|
||||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
|
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
|
||||||
err.Message = "error retrieving authority policy: force"
|
err.Message = "error retrieving authority policy: force"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
||||||
|
@ -124,7 +125,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
|
||||||
err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")
|
err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")
|
||||||
err.Message = "authority policy does not exist"
|
err.Message = "authority policy does not exist"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||||
|
@ -179,7 +181,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return policy, nil
|
return policy, nil
|
||||||
|
@ -234,11 +237,12 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mockMustAuthority(t, tc.auth)
|
||||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil)
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil)
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.GetAuthorityPolicy(w, req)
|
par.GetAuthorityPolicy(w, req)
|
||||||
|
@ -301,7 +305,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
||||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
|
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
|
||||||
err.Message = "error retrieving authority policy: force"
|
err.Message = "error retrieving authority policy: force"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
||||||
|
@ -316,7 +321,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
||||||
err := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
|
err := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
|
||||||
err.Message = "authority already has a policy"
|
err.Message = "authority already has a policy"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return &linkedca.Policy{}, nil
|
return &linkedca.Policy{}, nil
|
||||||
|
@ -332,7 +338,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
||||||
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
|
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
|
||||||
body := []byte("{?}")
|
body := []byte("{?}")
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||||
|
@ -358,7 +365,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
||||||
}
|
}
|
||||||
}`)
|
}`)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||||
|
@ -509,11 +517,13 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mockMustAuthority(t, tc.auth)
|
||||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.CreateAuthorityPolicy(w, req)
|
par.CreateAuthorityPolicy(w, req)
|
||||||
|
@ -586,7 +596,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
||||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
|
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
|
||||||
err.Message = "error retrieving authority policy: force"
|
err.Message = "error retrieving authority policy: force"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
||||||
|
@ -602,7 +613,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
||||||
err.Message = "authority policy does not exist"
|
err.Message = "authority policy does not exist"
|
||||||
err.Status = http.StatusNotFound
|
err.Status = http.StatusNotFound
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -625,7 +637,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
||||||
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
|
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
|
||||||
body := []byte("{?}")
|
body := []byte("{?}")
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return policy, nil
|
return policy, nil
|
||||||
|
@ -658,7 +671,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
||||||
}
|
}
|
||||||
}`)
|
}`)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return policy, nil
|
return policy, nil
|
||||||
|
@ -809,11 +823,13 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mockMustAuthority(t, tc.auth)
|
||||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.UpdateAuthorityPolicy(w, req)
|
par.UpdateAuthorityPolicy(w, req)
|
||||||
|
@ -886,7 +902,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
||||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
|
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
|
||||||
err.Message = "error retrieving authority policy: force"
|
err.Message = "error retrieving authority policy: force"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
return nil, admin.NewError(admin.ErrorServerInternalType, "force")
|
||||||
|
@ -902,7 +919,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
||||||
err.Message = "authority policy does not exist"
|
err.Message = "authority policy does not exist"
|
||||||
err.Status = http.StatusNotFound
|
err.Status = http.StatusNotFound
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -924,7 +942,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
||||||
err := admin.NewErrorISE("error deleting authority policy: force")
|
err := admin.NewErrorISE("error deleting authority policy: force")
|
||||||
err.Message = "error deleting authority policy: force"
|
err.Message = "error deleting authority policy: force"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return policy, nil
|
return policy, nil
|
||||||
|
@ -947,7 +966,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return policy, nil
|
return policy, nil
|
||||||
|
@ -963,11 +983,13 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mockMustAuthority(t, tc.auth)
|
||||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.DeleteAuthorityPolicy(w, req)
|
par.DeleteAuthorityPolicy(w, req)
|
||||||
|
@ -1033,6 +1055,7 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
|
||||||
err.Message = "provisioner policy does not exist"
|
err.Message = "provisioner policy does not exist"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
|
@ -1085,7 +1108,8 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
|
||||||
}
|
}
|
||||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
response: &testPolicyResponse{
|
response: &testPolicyResponse{
|
||||||
X509: &testX509Policy{
|
X509: &testX509Policy{
|
||||||
Allow: &testX509Names{
|
Allow: &testX509Names{
|
||||||
|
@ -1135,11 +1159,13 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mockMustAuthority(t, tc.auth)
|
||||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil)
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.GetProvisionerPolicy(w, req)
|
par.GetProvisionerPolicy(w, req)
|
||||||
|
@ -1214,6 +1240,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
||||||
err.Message = "provisioner provName already has a policy"
|
err.Message = "provisioner provName already has a policy"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 409,
|
statusCode: 409,
|
||||||
}
|
}
|
||||||
|
@ -1228,6 +1255,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
||||||
body := []byte("{?}")
|
body := []byte("{?}")
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
body: body,
|
body: body,
|
||||||
err: adminErr,
|
err: adminErr,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
|
@ -1251,7 +1279,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
||||||
}
|
}
|
||||||
}`)
|
}`)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||||
|
@ -1283,7 +1312,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||||
return &authority.PolicyError{
|
return &authority.PolicyError{
|
||||||
|
@ -1318,7 +1348,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||||
return &authority.PolicyError{
|
return &authority.PolicyError{
|
||||||
|
@ -1351,7 +1382,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||||
return nil
|
return nil
|
||||||
|
@ -1372,11 +1404,12 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mockMustAuthority(t, tc.auth)
|
||||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil)
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.CreateProvisionerPolicy(w, req)
|
par.CreateProvisionerPolicy(w, req)
|
||||||
|
@ -1452,6 +1485,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
||||||
err.Message = "provisioner policy does not exist"
|
err.Message = "provisioner policy does not exist"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
|
@ -1474,6 +1508,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
||||||
body := []byte("{?}")
|
body := []byte("{?}")
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
body: body,
|
body: body,
|
||||||
err: adminErr,
|
err: adminErr,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
|
@ -1505,7 +1540,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
||||||
}
|
}
|
||||||
}`)
|
}`)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
|
||||||
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
|
||||||
|
@ -1538,7 +1574,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||||
return &authority.PolicyError{
|
return &authority.PolicyError{
|
||||||
|
@ -1574,7 +1611,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||||
return &authority.PolicyError{
|
return &authority.PolicyError{
|
||||||
|
@ -1608,7 +1646,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||||
return nil
|
return nil
|
||||||
|
@ -1629,11 +1668,12 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mockMustAuthority(t, tc.auth)
|
||||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil)
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.UpdateProvisionerPolicy(w, req)
|
par.UpdateProvisionerPolicy(w, req)
|
||||||
|
@ -1710,6 +1750,7 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
|
||||||
err.Message = "provisioner policy does not exist"
|
err.Message = "provisioner policy does not exist"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
|
@ -1723,7 +1764,8 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
|
||||||
err := admin.NewErrorISE("error deleting provisioner policy: force")
|
err := admin.NewErrorISE("error deleting provisioner policy: force")
|
||||||
err.Message = "error deleting provisioner policy: force"
|
err.Message = "error deleting provisioner policy: force"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||||
return errors.New("force")
|
return errors.New("force")
|
||||||
|
@ -1740,7 +1782,8 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
|
||||||
}
|
}
|
||||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: &mockAdminAuthority{
|
auth: &mockAdminAuthority{
|
||||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||||
return nil
|
return nil
|
||||||
|
@ -1753,11 +1796,13 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mockMustAuthority(t, tc.auth)
|
||||||
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB)
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.DeleteProvisionerPolicy(w, req)
|
par.DeleteProvisionerPolicy(w, req)
|
||||||
|
@ -1828,6 +1873,7 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
|
||||||
err.Message = "ACME EAK policy does not exist"
|
err.Message = "ACME EAK policy does not exist"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
|
@ -1885,7 +1931,8 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
|
||||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||||
ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak)
|
ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
response: &testPolicyResponse{
|
response: &testPolicyResponse{
|
||||||
X509: &testX509Policy{
|
X509: &testX509Policy{
|
||||||
Allow: &testX509Names{
|
Allow: &testX509Names{
|
||||||
|
@ -1935,11 +1982,12 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB)
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil)
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.GetACMEAccountPolicy(w, req)
|
par.GetACMEAccountPolicy(w, req)
|
||||||
|
@ -2018,6 +2066,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
||||||
err.Message = "ACME EAK eakID already has a policy"
|
err.Message = "ACME EAK eakID already has a policy"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 409,
|
statusCode: 409,
|
||||||
}
|
}
|
||||||
|
@ -2036,6 +2085,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
||||||
body := []byte("{?}")
|
body := []byte("{?}")
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
body: body,
|
body: body,
|
||||||
err: adminErr,
|
err: adminErr,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
|
@ -2064,6 +2114,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
||||||
}`)
|
}`)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
body: body,
|
body: body,
|
||||||
err: adminErr,
|
err: adminErr,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
|
@ -2091,7 +2142,8 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
acmeDB: &acme.MockDB{
|
acmeDB: &acme.MockDB{
|
||||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||||
assert.Equal(t, "provID", provisionerID)
|
assert.Equal(t, "provID", provisionerID)
|
||||||
|
@ -2124,7 +2176,8 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
acmeDB: &acme.MockDB{
|
acmeDB: &acme.MockDB{
|
||||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||||
assert.Equal(t, "provID", provisionerID)
|
assert.Equal(t, "provID", provisionerID)
|
||||||
|
@ -2147,11 +2200,12 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB)
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.CreateACMEAccountPolicy(w, req)
|
par.CreateACMEAccountPolicy(w, req)
|
||||||
|
@ -2231,6 +2285,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
||||||
err.Message = "ACME EAK policy does not exist"
|
err.Message = "ACME EAK policy does not exist"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
|
@ -2257,6 +2312,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
||||||
body := []byte("{?}")
|
body := []byte("{?}")
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
body: body,
|
body: body,
|
||||||
err: adminErr,
|
err: adminErr,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
|
@ -2293,6 +2349,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
||||||
}`)
|
}`)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
body: body,
|
body: body,
|
||||||
err: adminErr,
|
err: adminErr,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
|
@ -2321,7 +2378,8 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
acmeDB: &acme.MockDB{
|
acmeDB: &acme.MockDB{
|
||||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||||
assert.Equal(t, "provID", provisionerID)
|
assert.Equal(t, "provID", provisionerID)
|
||||||
|
@ -2355,7 +2413,8 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
||||||
body, err := protojson.Marshal(policy)
|
body, err := protojson.Marshal(policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
acmeDB: &acme.MockDB{
|
acmeDB: &acme.MockDB{
|
||||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||||
assert.Equal(t, "provID", provisionerID)
|
assert.Equal(t, "provID", provisionerID)
|
||||||
|
@ -2378,11 +2437,12 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB)
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.UpdateACMEAccountPolicy(w, req)
|
par.UpdateACMEAccountPolicy(w, req)
|
||||||
|
@ -2462,6 +2522,7 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
|
||||||
err.Message = "ACME EAK policy does not exist"
|
err.Message = "ACME EAK policy does not exist"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 404,
|
statusCode: 404,
|
||||||
}
|
}
|
||||||
|
@ -2487,7 +2548,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
|
||||||
err := admin.NewErrorISE("error deleting ACME EAK policy: force")
|
err := admin.NewErrorISE("error deleting ACME EAK policy: force")
|
||||||
err.Message = "error deleting ACME EAK policy: force"
|
err.Message = "error deleting ACME EAK policy: force"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
acmeDB: &acme.MockDB{
|
acmeDB: &acme.MockDB{
|
||||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||||
assert.Equal(t, "provID", provisionerID)
|
assert.Equal(t, "provID", provisionerID)
|
||||||
|
@ -2518,7 +2580,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
|
||||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||||
ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak)
|
ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
acmeDB: &acme.MockDB{
|
acmeDB: &acme.MockDB{
|
||||||
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||||
assert.Equal(t, "provID", provisionerID)
|
assert.Equal(t, "provID", provisionerID)
|
||||||
|
@ -2533,11 +2596,12 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB)
|
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
|
||||||
|
par := NewPolicyAdminResponder()
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
par.DeleteACMEAccountPolicy(w, req)
|
par.DeleteACMEAccountPolicy(w, req)
|
||||||
|
|
|
@ -23,29 +23,31 @@ type GetProvisionersResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProvisioner returns the requested provisioner, or an error.
|
// GetProvisioner returns the requested provisioner, or an error.
|
||||||
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
func GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
|
||||||
|
|
||||||
id := r.URL.Query().Get("id")
|
|
||||||
name := chi.URLParam(r, "name")
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
p provisioner.Interface
|
p provisioner.Interface
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
id := r.URL.Query().Get("id")
|
||||||
|
name := chi.URLParam(r, "name")
|
||||||
|
auth := mustAuthority(ctx)
|
||||||
|
db := admin.MustFromContext(ctx)
|
||||||
|
|
||||||
if len(id) > 0 {
|
if len(id) > 0 {
|
||||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
prov, err := db.GetProvisioner(ctx, p.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
|
@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
||||||
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
func GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||||
cursor, limit, err := api.ParseCursor(r)
|
cursor, limit, err := api.ParseCursor(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||||
|
@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p, next, err := h.auth.GetProvisioners(cursor, limit)
|
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProvisioner creates a new prov.
|
// CreateProvisioner creates a new prov.
|
||||||
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
var prov = new(linkedca.Provisioner)
|
var prov = new(linkedca.Provisioner)
|
||||||
if err := read.ProtoJSON(r.Body, prov); err != nil {
|
if err := read.ProtoJSON(r.Body, prov); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
|
@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
|
if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteProvisioner deletes a provisioner.
|
// DeleteProvisioner deletes a provisioner.
|
||||||
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
id := r.URL.Query().Get("id")
|
|
||||||
name := chi.URLParam(r, "name")
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
p provisioner.Interface
|
p provisioner.Interface
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
id := r.URL.Query().Get("id")
|
||||||
|
name := chi.URLParam(r, "name")
|
||||||
|
auth := mustAuthority(r.Context())
|
||||||
|
|
||||||
if len(id) > 0 {
|
if len(id) > 0 {
|
||||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProvisioner updates an existing prov.
|
// UpdateProvisioner updates an existing prov.
|
||||||
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
var nu = new(linkedca.Provisioner)
|
var nu = new(linkedca.Provisioner)
|
||||||
if err := read.ProtoJSON(r.Body, nu); err != nil {
|
if err := read.ProtoJSON(r.Body, nu); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
name := chi.URLParam(r, "name")
|
name := chi.URLParam(r, "name")
|
||||||
_old, err := h.auth.LoadProvisionerByName(name)
|
auth := mustAuthority(ctx)
|
||||||
|
db := admin.MustFromContext(ctx)
|
||||||
|
|
||||||
|
p, err := auth.LoadProvisionerByName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
|
old, err := db.GetProvisioner(r.Context(), p.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,6 +50,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
req: req,
|
req: req,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: &admin.Error{
|
err: &admin.Error{
|
||||||
Type: admin.ErrorServerInternalType.String(),
|
Type: admin.ErrorServerInternalType.String(),
|
||||||
|
@ -74,6 +75,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
req: req,
|
req: req,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: &admin.Error{
|
err: &admin.Error{
|
||||||
Type: admin.ErrorServerInternalType.String(),
|
Type: admin.ErrorServerInternalType.String(),
|
||||||
|
@ -156,13 +158,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
adminDB: tc.adminDB,
|
req := tc.req.WithContext(ctx)
|
||||||
}
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetProvisioner(w, req)
|
GetProvisioner(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -280,12 +280,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
req := tc.req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetProvisioners(w, req)
|
GetProvisioners(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -405,13 +403,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.CreateProvisioner(w, req)
|
CreateProvisioner(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -571,12 +567,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
req := tc.req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.DeleteProvisioner(w, req)
|
DeleteProvisioner(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -625,6 +619,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
body: body,
|
body: body,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: &admin.Error{
|
err: &admin.Error{
|
||||||
Type: "badRequest",
|
Type: "badRequest",
|
||||||
|
@ -654,6 +649,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
body: body,
|
body: body,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: auth,
|
auth: auth,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: &admin.Error{
|
err: &admin.Error{
|
||||||
|
@ -1061,14 +1057,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
adminDB: tc.adminDB,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.UpdateProvisioner(w, req)
|
UpdateProvisioner(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
|
@ -76,6 +76,29 @@ type DB interface {
|
||||||
DeleteAuthorityPolicy(ctx context.Context) error
|
DeleteAuthorityPolicy(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dbKey struct{}
|
||||||
|
|
||||||
|
// NewContext adds the given admin database to the context.
|
||||||
|
func NewContext(ctx context.Context, db DB) context.Context {
|
||||||
|
return context.WithValue(ctx, dbKey{}, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext returns the current admin database from the given context.
|
||||||
|
func FromContext(ctx context.Context) (db DB, ok bool) {
|
||||||
|
db, ok = ctx.Value(dbKey{}).(DB)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFromContext returns the current admin database from the given context. It
|
||||||
|
// will panic if it's not in the context.
|
||||||
|
func MustFromContext(ctx context.Context) DB {
|
||||||
|
if db, ok := FromContext(ctx); !ok {
|
||||||
|
panic("admin database is not in the context")
|
||||||
|
} else {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MockDB is an implementation of the DB interface that should only be used as
|
// MockDB is an implementation of the DB interface that should only be used as
|
||||||
// a mock in tests.
|
// a mock in tests.
|
||||||
type MockDB struct {
|
type MockDB struct {
|
||||||
|
|
|
@ -167,6 +167,29 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type authorityKey struct{}
|
||||||
|
|
||||||
|
// NewContext adds the given authority to the context.
|
||||||
|
func NewContext(ctx context.Context, a *Authority) context.Context {
|
||||||
|
return context.WithValue(ctx, authorityKey{}, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext returns the current authority from the given context.
|
||||||
|
func FromContext(ctx context.Context) (a *Authority, ok bool) {
|
||||||
|
a, ok = ctx.Value(authorityKey{}).(*Authority)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFromContext returns the current authority from the given context. It will
|
||||||
|
// panic if the authority is not in the context.
|
||||||
|
func MustFromContext(ctx context.Context) *Authority {
|
||||||
|
if a, ok := FromContext(ctx); !ok {
|
||||||
|
panic("authority is not in the context")
|
||||||
|
} else {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ReloadAdminResources reloads admins and provisioners from the DB.
|
// ReloadAdminResources reloads admins and provisioners from the DB.
|
||||||
func (a *Authority) ReloadAdminResources(ctx context.Context) error {
|
func (a *Authority) ReloadAdminResources(ctx context.Context) error {
|
||||||
var (
|
var (
|
||||||
|
@ -235,6 +258,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
ctx := NewContext(context.Background(), a)
|
||||||
|
|
||||||
// Set password if they are not set.
|
// Set password if they are not set.
|
||||||
var configPassword []byte
|
var configPassword []byte
|
||||||
|
@ -270,7 +294,7 @@ func (a *Authority) init() error {
|
||||||
if a.config.KMS != nil {
|
if a.config.KMS != nil {
|
||||||
options = *a.config.KMS
|
options = *a.config.KMS
|
||||||
}
|
}
|
||||||
a.keyManager, err = kms.New(context.Background(), options)
|
a.keyManager, err = kms.New(ctx, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -300,7 +324,7 @@ func (a *Authority) init() error {
|
||||||
|
|
||||||
// Configure linked RA
|
// Configure linked RA
|
||||||
if linkedcaClient != nil && options.CertificateAuthority == "" {
|
if linkedcaClient != nil && options.CertificateAuthority == "" {
|
||||||
conf, err := linkedcaClient.GetConfiguration(context.Background())
|
conf, err := linkedcaClient.GetConfiguration(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -334,7 +358,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.x509CAService, err = cas.New(context.Background(), options)
|
a.x509CAService, err = cas.New(ctx, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -521,7 +545,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.scepService, err = scep.NewService(context.Background(), options)
|
a.scepService, err = scep.NewService(ctx, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -543,19 +567,19 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
provs, err := a.adminDB.GetProvisioners(context.Background())
|
provs, err := a.adminDB.GetProvisioners(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
||||||
}
|
}
|
||||||
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
|
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
|
||||||
// Create First Provisioner
|
// Create First Provisioner
|
||||||
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password))
|
prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return admin.WrapErrorISE(err, "error creating first provisioner")
|
return admin.WrapErrorISE(err, "error creating first provisioner")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create first admin
|
// Create first admin
|
||||||
if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{
|
if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
|
||||||
ProvisionerId: prov.Id,
|
ProvisionerId: prov.Id,
|
||||||
Subject: "step",
|
Subject: "step",
|
||||||
Type: linkedca.Admin_SUPER_ADMIN,
|
Type: linkedca.Admin_SUPER_ADMIN,
|
||||||
|
@ -571,7 +595,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load x509 and SSH Policy Engines
|
// Load x509 and SSH Policy Engines
|
||||||
if err := a.reloadPolicyEngines(context.Background()); err != nil {
|
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -596,6 +620,15 @@ func (a *Authority) init() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetID returns the define authority id or a zero uuid.
|
||||||
|
func (a *Authority) GetID() string {
|
||||||
|
const zeroUUID = "00000000-0000-0000-0000-000000000000"
|
||||||
|
if id := a.config.AuthorityConfig.AuthorityID; id != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return zeroUUID
|
||||||
|
}
|
||||||
|
|
||||||
// GetDatabase returns the authority database. If the configuration does not
|
// GetDatabase returns the authority database. If the configuration does not
|
||||||
// define a database, GetDatabase will return a db.SimpleDB instance.
|
// define a database, GetDatabase will return a db.SimpleDB instance.
|
||||||
func (a *Authority) GetDatabase() db.AuthDB {
|
func (a *Authority) GetDatabase() db.AuthDB {
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthority_GetID(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
authorityID string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"ok", fields{""}, "00000000-0000-0000-0000-000000000000"},
|
||||||
|
{"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Authority{
|
||||||
|
config: &config.Config{
|
||||||
|
AuthorityConfig: &config.AuthConfig{
|
||||||
|
AuthorityID: tt.fields.authorityID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if got := a.GetID(); got != tt.want {
|
||||||
|
t.Errorf("Authority.GetID() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -251,8 +251,7 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio
|
||||||
// AuthorizeSign authorizes a signature request by validating and authenticating
|
// AuthorizeSign authorizes a signature request by validating and authenticating
|
||||||
// a token that must be sent w/ the request.
|
// a token that must be sent w/ the request.
|
||||||
//
|
//
|
||||||
// NOTE: This method is deprecated and should not be used. We make it available
|
// Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error).
|
||||||
// in the short term os as not to break existing clients.
|
|
||||||
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
|
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
|
||||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||||
return a.Authorize(ctx, token)
|
return a.Authorize(ctx, token)
|
||||||
|
|
|
@ -54,7 +54,11 @@ func startCABootstrapServer() *httptest.Server {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
baseContext := buildContext(ca.auth, nil, nil, nil)
|
||||||
srv.Config.Handler = ca.srv.Handler
|
srv.Config.Handler = ca.srv.Handler
|
||||||
|
srv.Config.BaseContext = func(net.Listener) context.Context {
|
||||||
|
return baseContext
|
||||||
|
}
|
||||||
srv.TLS = ca.srv.TLSConfig
|
srv.TLS = ca.srv.TLSConfig
|
||||||
srv.StartTLS()
|
srv.StartTLS()
|
||||||
// Force the use of GetCertificate on IPs
|
// Force the use of GetCertificate on IPs
|
||||||
|
|
80
ca/ca.go
80
ca/ca.go
|
@ -1,10 +1,12 @@
|
||||||
package ca
|
package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -18,6 +20,7 @@ import (
|
||||||
acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
|
acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
adminAPI "github.com/smallstep/certificates/authority/admin/api"
|
adminAPI "github.com/smallstep/certificates/authority/admin/api"
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
|
@ -170,10 +173,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
insecureHandler := http.Handler(insecureMux)
|
insecureHandler := http.Handler(insecureMux)
|
||||||
|
|
||||||
// Add regular CA api endpoints in / and /1.0
|
// Add regular CA api endpoints in / and /1.0
|
||||||
routerHandler := api.New(auth)
|
api.Route(mux)
|
||||||
routerHandler.Route(mux)
|
|
||||||
mux.Route("/1.0", func(r chi.Router) {
|
mux.Route("/1.0", func(r chi.Router) {
|
||||||
routerHandler.Route(r)
|
api.Route(r)
|
||||||
})
|
})
|
||||||
|
|
||||||
//Add ACME api endpoints in /acme and /1.0/acme
|
//Add ACME api endpoints in /acme and /1.0/acme
|
||||||
|
@ -187,49 +189,41 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
dns = fmt.Sprintf("%s:%s", dns, port)
|
dns = fmt.Sprintf("%s:%s", dns, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ACME Router
|
// ACME Router is only available if we have a database.
|
||||||
prefix := "acme"
|
|
||||||
var acmeDB acme.DB
|
var acmeDB acme.DB
|
||||||
if cfg.DB == nil {
|
var acmeLinker acme.Linker
|
||||||
acmeDB = nil
|
if cfg.DB != nil {
|
||||||
} else {
|
|
||||||
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error configuring ACME DB interface")
|
return nil, errors.Wrap(err, "error configuring ACME DB interface")
|
||||||
}
|
}
|
||||||
|
acmeLinker = acme.NewLinker(dns, "acme")
|
||||||
|
mux.Route("/acme", func(r chi.Router) {
|
||||||
|
acmeAPI.Route(r)
|
||||||
|
})
|
||||||
|
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
||||||
|
// of the ACME spec.
|
||||||
|
mux.Route("/2.0/acme", func(r chi.Router) {
|
||||||
|
acmeAPI.Route(r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
|
|
||||||
Backdate: *cfg.AuthorityConfig.Backdate,
|
|
||||||
DB: acmeDB,
|
|
||||||
DNS: dns,
|
|
||||||
Prefix: prefix,
|
|
||||||
CA: auth,
|
|
||||||
})
|
|
||||||
mux.Route("/"+prefix, func(r chi.Router) {
|
|
||||||
acmeHandler.Route(r)
|
|
||||||
})
|
|
||||||
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
|
||||||
// of the ACME spec.
|
|
||||||
mux.Route("/2.0/"+prefix, func(r chi.Router) {
|
|
||||||
acmeHandler.Route(r)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Admin API Router
|
// Admin API Router
|
||||||
if cfg.AuthorityConfig.EnableAdmin {
|
if cfg.AuthorityConfig.EnableAdmin {
|
||||||
adminDB := auth.GetAdminDatabase()
|
adminDB := auth.GetAdminDatabase()
|
||||||
if adminDB != nil {
|
if adminDB != nil {
|
||||||
acmeAdminResponder := adminAPI.NewACMEAdminResponder()
|
acmeAdminResponder := adminAPI.NewACMEAdminResponder()
|
||||||
policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB, acmeDB)
|
policyAdminResponder := adminAPI.NewPolicyAdminResponder()
|
||||||
adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder, policyAdminResponder)
|
|
||||||
mux.Route("/admin", func(r chi.Router) {
|
mux.Route("/admin", func(r chi.Router) {
|
||||||
adminHandler.Route(r)
|
adminAPI.Route(r, acmeAdminResponder, policyAdminResponder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var scepAuthority *scep.Authority
|
||||||
if ca.shouldServeSCEPEndpoints() {
|
if ca.shouldServeSCEPEndpoints() {
|
||||||
scepPrefix := "scep"
|
scepPrefix := "scep"
|
||||||
scepAuthority, err := scep.New(auth, scep.AuthorityOptions{
|
scepAuthority, err = scep.New(auth, scep.AuthorityOptions{
|
||||||
Service: auth.GetSCEPService(),
|
Service: auth.GetSCEPService(),
|
||||||
DNS: dns,
|
DNS: dns,
|
||||||
Prefix: scepPrefix,
|
Prefix: scepPrefix,
|
||||||
|
@ -237,13 +231,12 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error creating SCEP authority")
|
return nil, errors.Wrap(err, "error creating SCEP authority")
|
||||||
}
|
}
|
||||||
scepRouterHandler := scepAPI.New(scepAuthority)
|
|
||||||
|
|
||||||
// According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10),
|
// According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10),
|
||||||
// SCEP operations are performed using HTTP, so that's why the API is mounted
|
// SCEP operations are performed using HTTP, so that's why the API is mounted
|
||||||
// to the insecure mux.
|
// to the insecure mux.
|
||||||
insecureMux.Route("/"+scepPrefix, func(r chi.Router) {
|
insecureMux.Route("/"+scepPrefix, func(r chi.Router) {
|
||||||
scepRouterHandler.Route(r)
|
scepAPI.Route(r)
|
||||||
})
|
})
|
||||||
|
|
||||||
// The RFC also mentions usage of HTTPS, but seems to advise
|
// The RFC also mentions usage of HTTPS, but seems to advise
|
||||||
|
@ -253,7 +246,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
// as well as HTTPS can be used to request certificates
|
// as well as HTTPS can be used to request certificates
|
||||||
// using SCEP.
|
// using SCEP.
|
||||||
mux.Route("/"+scepPrefix, func(r chi.Router) {
|
mux.Route("/"+scepPrefix, func(r chi.Router) {
|
||||||
scepRouterHandler.Route(r)
|
scepAPI.Route(r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -280,7 +273,13 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
insecureHandler = logger.Middleware(insecureHandler)
|
insecureHandler = logger.Middleware(insecureHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create context with all the necessary values.
|
||||||
|
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
|
||||||
|
|
||||||
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
||||||
|
ca.srv.BaseContext = func(net.Listener) context.Context {
|
||||||
|
return baseContext
|
||||||
|
}
|
||||||
|
|
||||||
// only start the insecure server if the insecure address is configured
|
// only start the insecure server if the insecure address is configured
|
||||||
// and, currently, also only when it should serve SCEP endpoints.
|
// and, currently, also only when it should serve SCEP endpoints.
|
||||||
|
@ -290,11 +289,32 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
// will probably introduce more complexity in terms of graceful
|
// will probably introduce more complexity in terms of graceful
|
||||||
// reload.
|
// reload.
|
||||||
ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil)
|
ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil)
|
||||||
|
ca.insecureSrv.BaseContext = func(net.Listener) context.Context {
|
||||||
|
return baseContext
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ca, nil
|
return ca, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildContext builds the server base context.
|
||||||
|
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context {
|
||||||
|
ctx := authority.NewContext(context.Background(), a)
|
||||||
|
if authDB := a.GetDatabase(); authDB != nil {
|
||||||
|
ctx = db.NewContext(ctx, authDB)
|
||||||
|
}
|
||||||
|
if adminDB := a.GetAdminDatabase(); adminDB != nil {
|
||||||
|
ctx = admin.NewContext(ctx, adminDB)
|
||||||
|
}
|
||||||
|
if scepAuthority != nil {
|
||||||
|
ctx = scep.NewContext(ctx, scepAuthority)
|
||||||
|
}
|
||||||
|
if acmeDB != nil {
|
||||||
|
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil)
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
// Run starts the CA calling to the server ListenAndServe method.
|
// Run starts the CA calling to the server ListenAndServe method.
|
||||||
func (ca *CA) Run() error {
|
func (ca *CA) Run() error {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
|
@ -2,6 +2,7 @@ package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
|
@ -281,7 +282,8 @@ ZEp7knvU2psWRw==
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||||
|
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||||
|
|
||||||
if assert.Equals(t, rr.Code, tc.status) {
|
if assert.Equals(t, rr.Code, tc.status) {
|
||||||
body := &ClosingBuffer{rr.Body}
|
body := &ClosingBuffer{rr.Body}
|
||||||
|
@ -360,7 +362,8 @@ func TestCAProvisioners(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||||
|
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||||
|
|
||||||
if assert.Equals(t, rr.Code, tc.status) {
|
if assert.Equals(t, rr.Code, tc.status) {
|
||||||
body := &ClosingBuffer{rr.Body}
|
body := &ClosingBuffer{rr.Body}
|
||||||
|
@ -426,7 +429,8 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||||
|
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||||
|
|
||||||
if assert.Equals(t, rr.Code, tc.status) {
|
if assert.Equals(t, rr.Code, tc.status) {
|
||||||
body := &ClosingBuffer{rr.Body}
|
body := &ClosingBuffer{rr.Body}
|
||||||
|
@ -487,7 +491,8 @@ func TestCARoot(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||||
|
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||||
|
|
||||||
if assert.Equals(t, rr.Code, tc.status) {
|
if assert.Equals(t, rr.Code, tc.status) {
|
||||||
body := &ClosingBuffer{rr.Body}
|
body := &ClosingBuffer{rr.Body}
|
||||||
|
@ -534,7 +539,8 @@ func TestCAHealth(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||||
|
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||||
|
|
||||||
if assert.Equals(t, rr.Code, tc.status) {
|
if assert.Equals(t, rr.Code, tc.status) {
|
||||||
body := &ClosingBuffer{rr.Body}
|
body := &ClosingBuffer{rr.Body}
|
||||||
|
@ -628,7 +634,8 @@ func TestCARenew(t *testing.T) {
|
||||||
rq.TLS = tc.tlsConnState
|
rq.TLS = tc.tlsConnState
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
tc.ca.srv.Handler.ServeHTTP(rr, rq)
|
ctx := authority.NewContext(context.Background(), tc.ca.auth)
|
||||||
|
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
|
||||||
|
|
||||||
if assert.Equals(t, rr.Code, tc.status) {
|
if assert.Equals(t, rr.Code, tc.status) {
|
||||||
body := &ClosingBuffer{rr.Body}
|
body := &ClosingBuffer{rr.Body}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -77,7 +78,12 @@ func startCATestServer() *httptest.Server {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
// Use a httptest.Server instead
|
// Use a httptest.Server instead
|
||||||
return startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
|
srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
|
||||||
|
baseContext := buildContext(ca.auth, nil, nil, nil)
|
||||||
|
srv.Config.BaseContext = func(net.Listener) context.Context {
|
||||||
|
return baseContext
|
||||||
|
}
|
||||||
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
|
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
|
||||||
|
|
24
db/db.go
24
db/db.go
|
@ -1,6 +1,7 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -56,6 +57,29 @@ type AuthDB interface {
|
||||||
Shutdown() error
|
Shutdown() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dbKey struct{}
|
||||||
|
|
||||||
|
// NewContext adds the given authority database to the context.
|
||||||
|
func NewContext(ctx context.Context, db AuthDB) context.Context {
|
||||||
|
return context.WithValue(ctx, dbKey{}, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext returns the current authority database from the given context.
|
||||||
|
func FromContext(ctx context.Context) (db AuthDB, ok bool) {
|
||||||
|
db, ok = ctx.Value(dbKey{}).(AuthDB)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFromContext returns the current database from the given context. It
|
||||||
|
// will panic if it's not in the context.
|
||||||
|
func MustFromContext(ctx context.Context) AuthDB {
|
||||||
|
if db, ok := FromContext(ctx); !ok {
|
||||||
|
panic("authority database is not in the context")
|
||||||
|
} else {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CertificateStorer is an extension of AuthDB that allows to store
|
// CertificateStorer is an extension of AuthDB that allows to store
|
||||||
// certificates.
|
// certificates.
|
||||||
type CertificateStorer interface {
|
type CertificateStorer interface {
|
||||||
|
|
134
scep/api/api.go
134
scep/api/api.go
|
@ -38,8 +38,8 @@ type request struct {
|
||||||
Message []byte
|
Message []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// response is a SCEP server response.
|
// Response is a SCEP server Response.
|
||||||
type response struct {
|
type Response struct {
|
||||||
Operation string
|
Operation string
|
||||||
CACertNum int
|
CACertNum int
|
||||||
Data []byte
|
Data []byte
|
||||||
|
@ -52,25 +52,48 @@ type handler struct {
|
||||||
auth *scep.Authority
|
auth *scep.Authority
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Route traffic and implement the Router interface.
|
||||||
|
//
|
||||||
|
// Deprecated: use scep.Route(r api.Router)
|
||||||
|
func (h *handler) Route(r api.Router) {
|
||||||
|
route(r, func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := scep.NewContext(r.Context(), h.auth)
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// New returns a new SCEP API router.
|
// New returns a new SCEP API router.
|
||||||
|
//
|
||||||
|
// Deprecated: use scep.Route(r api.Router)
|
||||||
func New(auth *scep.Authority) api.RouterHandler {
|
func New(auth *scep.Authority) api.RouterHandler {
|
||||||
return &handler{
|
return &handler{auth: auth}
|
||||||
auth: auth,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route traffic and implement the Router interface.
|
// Route traffic and implement the Router interface.
|
||||||
func (h *handler) Route(r api.Router) {
|
func Route(r api.Router) {
|
||||||
getLink := h.auth.GetLinkExplicit
|
route(r, nil)
|
||||||
r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get))
|
}
|
||||||
r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get))
|
|
||||||
r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post))
|
func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc) {
|
||||||
r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post))
|
getHandler := lookupProvisioner(Get)
|
||||||
|
postHandler := lookupProvisioner(Post)
|
||||||
|
|
||||||
|
// For backward compatibility.
|
||||||
|
if middleware != nil {
|
||||||
|
getHandler = middleware(getHandler)
|
||||||
|
postHandler = middleware(postHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.MethodFunc(http.MethodGet, "/{provisionerName}/*", getHandler)
|
||||||
|
r.MethodFunc(http.MethodGet, "/{provisionerName}", getHandler)
|
||||||
|
r.MethodFunc(http.MethodPost, "/{provisionerName}/*", postHandler)
|
||||||
|
r.MethodFunc(http.MethodPost, "/{provisionerName}", postHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get handles all SCEP GET requests
|
// Get handles all SCEP GET requests
|
||||||
func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
|
func Get(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
req, err := decodeRequest(r)
|
req, err := decodeRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fail(w, fmt.Errorf("invalid scep get request: %w", err))
|
fail(w, fmt.Errorf("invalid scep get request: %w", err))
|
||||||
|
@ -78,15 +101,15 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
var res response
|
var res Response
|
||||||
|
|
||||||
switch req.Operation {
|
switch req.Operation {
|
||||||
case opnGetCACert:
|
case opnGetCACert:
|
||||||
res, err = h.GetCACert(ctx)
|
res, err = GetCACert(ctx)
|
||||||
case opnGetCACaps:
|
case opnGetCACaps:
|
||||||
res, err = h.GetCACaps(ctx)
|
res, err = GetCACaps(ctx)
|
||||||
case opnPKIOperation:
|
case opnPKIOperation:
|
||||||
res, err = h.PKIOperation(ctx, req)
|
res, err = PKIOperation(ctx, req)
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unknown operation: %s", req.Operation)
|
err = fmt.Errorf("unknown operation: %s", req.Operation)
|
||||||
}
|
}
|
||||||
|
@ -100,20 +123,17 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Post handles all SCEP POST requests
|
// Post handles all SCEP POST requests
|
||||||
func (h *handler) Post(w http.ResponseWriter, r *http.Request) {
|
func Post(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
req, err := decodeRequest(r)
|
req, err := decodeRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fail(w, fmt.Errorf("invalid scep post request: %w", err))
|
fail(w, fmt.Errorf("invalid scep post request: %w", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
var res Response
|
||||||
var res response
|
|
||||||
|
|
||||||
switch req.Operation {
|
switch req.Operation {
|
||||||
case opnPKIOperation:
|
case opnPKIOperation:
|
||||||
res, err = h.PKIOperation(ctx, req)
|
res, err = PKIOperation(r.Context(), req)
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unknown operation: %s", req.Operation)
|
err = fmt.Errorf("unknown operation: %s", req.Operation)
|
||||||
}
|
}
|
||||||
|
@ -127,7 +147,6 @@ func (h *handler) Post(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeRequest(r *http.Request) (request, error) {
|
func decodeRequest(r *http.Request) (request, error) {
|
||||||
|
|
||||||
defer r.Body.Close()
|
defer r.Body.Close()
|
||||||
|
|
||||||
method := r.Method
|
method := r.Method
|
||||||
|
@ -179,9 +198,8 @@ func decodeRequest(r *http.Request) (request, error) {
|
||||||
|
|
||||||
// lookupProvisioner loads the provisioner associated with the request.
|
// lookupProvisioner loads the provisioner associated with the request.
|
||||||
// Responds 404 if the provisioner does not exist.
|
// Responds 404 if the provisioner does not exist.
|
||||||
func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
name := chi.URLParam(r, "provisionerName")
|
name := chi.URLParam(r, "provisionerName")
|
||||||
provisionerName, err := url.PathUnescape(name)
|
provisionerName, err := url.PathUnescape(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -189,7 +207,9 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := h.auth.LoadProvisionerByName(provisionerName)
|
ctx := r.Context()
|
||||||
|
auth := scep.MustFromContext(ctx)
|
||||||
|
p, err := auth.LoadProvisionerByName(provisionerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fail(w, err)
|
fail(w, err)
|
||||||
return
|
return
|
||||||
|
@ -201,25 +221,24 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov))
|
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov))
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCACert returns the CA certificates in a SCEP response
|
// GetCACert returns the CA certificates in a SCEP response
|
||||||
func (h *handler) GetCACert(ctx context.Context) (response, error) {
|
func GetCACert(ctx context.Context) (Response, error) {
|
||||||
|
auth := scep.MustFromContext(ctx)
|
||||||
certs, err := h.auth.GetCACertificates(ctx)
|
certs, err := auth.GetCACertificates(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return response{}, err
|
return Response{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(certs) == 0 {
|
if len(certs) == 0 {
|
||||||
return response{}, errors.New("missing CA cert")
|
return Response{}, errors.New("missing CA cert")
|
||||||
}
|
}
|
||||||
|
|
||||||
res := response{
|
res := Response{
|
||||||
Operation: opnGetCACert,
|
Operation: opnGetCACert,
|
||||||
CACertNum: len(certs),
|
CACertNum: len(certs),
|
||||||
}
|
}
|
||||||
|
@ -232,7 +251,7 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) {
|
||||||
// not signed or encrypted data has to be returned.
|
// not signed or encrypted data has to be returned.
|
||||||
data, err := microscep.DegenerateCertificates(certs)
|
data, err := microscep.DegenerateCertificates(certs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return response{}, err
|
return Response{}, err
|
||||||
}
|
}
|
||||||
res.Data = data
|
res.Data = data
|
||||||
}
|
}
|
||||||
|
@ -241,11 +260,11 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCACaps returns the CA capabilities in a SCEP response
|
// GetCACaps returns the CA capabilities in a SCEP response
|
||||||
func (h *handler) GetCACaps(ctx context.Context) (response, error) {
|
func GetCACaps(ctx context.Context) (Response, error) {
|
||||||
|
auth := scep.MustFromContext(ctx)
|
||||||
|
caps := auth.GetCACaps(ctx)
|
||||||
|
|
||||||
caps := h.auth.GetCACaps(ctx)
|
res := Response{
|
||||||
|
|
||||||
res := response{
|
|
||||||
Operation: opnGetCACaps,
|
Operation: opnGetCACaps,
|
||||||
Data: formatCapabilities(caps),
|
Data: formatCapabilities(caps),
|
||||||
}
|
}
|
||||||
|
@ -254,13 +273,12 @@ func (h *handler) GetCACaps(ctx context.Context) (response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// PKIOperation performs PKI operations and returns a SCEP response
|
// PKIOperation performs PKI operations and returns a SCEP response
|
||||||
func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) {
|
func PKIOperation(ctx context.Context, req request) (Response, error) {
|
||||||
|
|
||||||
// parse the message using microscep implementation
|
// parse the message using microscep implementation
|
||||||
microMsg, err := microscep.ParsePKIMessage(req.Message)
|
microMsg, err := microscep.ParsePKIMessage(req.Message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// return the error, because we can't use the msg for creating a CertRep
|
// return the error, because we can't use the msg for creating a CertRep
|
||||||
return response{}, err
|
return Response{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is essentially doing the same as microscep.ParsePKIMessage, but
|
// this is essentially doing the same as microscep.ParsePKIMessage, but
|
||||||
|
@ -268,7 +286,7 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
|
||||||
// wrapper for the microscep implementation.
|
// wrapper for the microscep implementation.
|
||||||
p7, err := pkcs7.Parse(microMsg.Raw)
|
p7, err := pkcs7.Parse(microMsg.Raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return response{}, err
|
return Response{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy over properties to our internal PKIMessage
|
// copy over properties to our internal PKIMessage
|
||||||
|
@ -280,8 +298,9 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
|
||||||
P7: p7,
|
P7: p7,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.auth.DecryptPKIEnvelope(ctx, msg); err != nil {
|
auth := scep.MustFromContext(ctx)
|
||||||
return response{}, err
|
if err := auth.DecryptPKIEnvelope(ctx, msg); err != nil {
|
||||||
|
return Response{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: at this point we have sufficient information for returning nicely signed CertReps
|
// NOTE: at this point we have sufficient information for returning nicely signed CertReps
|
||||||
|
@ -293,13 +312,13 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
|
||||||
// a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients.
|
// a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients.
|
||||||
// We'll have to see how it works out.
|
// We'll have to see how it works out.
|
||||||
if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq {
|
if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq {
|
||||||
challengeMatches, err := h.auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword)
|
challengeMatches, err := auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password"))
|
return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password"))
|
||||||
}
|
}
|
||||||
if !challengeMatches {
|
if !challengeMatches {
|
||||||
// TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too.
|
// TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too.
|
||||||
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided"))
|
return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -311,12 +330,12 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
|
||||||
// Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification
|
// Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification
|
||||||
// of the client cert is not.
|
// of the client cert is not.
|
||||||
|
|
||||||
certRep, err := h.auth.SignCSR(ctx, csr, msg)
|
certRep, err := auth.SignCSR(ctx, csr, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err))
|
return createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
res := response{
|
res := Response{
|
||||||
Operation: opnPKIOperation,
|
Operation: opnPKIOperation,
|
||||||
Data: certRep.Raw,
|
Data: certRep.Raw,
|
||||||
Certificate: certRep.Certificate,
|
Certificate: certRep.Certificate,
|
||||||
|
@ -330,7 +349,7 @@ func formatCapabilities(caps []string) []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeResponse writes a SCEP response back to the SCEP client.
|
// writeResponse writes a SCEP response back to the SCEP client.
|
||||||
func writeResponse(w http.ResponseWriter, res response) {
|
func writeResponse(w http.ResponseWriter, res Response) {
|
||||||
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
log.Error(w, res.Error)
|
log.Error(w, res.Error)
|
||||||
|
@ -350,19 +369,20 @@ func fail(w http.ResponseWriter, err error) {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) {
|
func createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (Response, error) {
|
||||||
certRepMsg, err := h.auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error())
|
auth := scep.MustFromContext(ctx)
|
||||||
|
certRepMsg, err := auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return response{}, err
|
return Response{}, err
|
||||||
}
|
}
|
||||||
return response{
|
return Response{
|
||||||
Operation: opnPKIOperation,
|
Operation: opnPKIOperation,
|
||||||
Data: certRepMsg.Raw,
|
Data: certRepMsg.Raw,
|
||||||
Error: failError,
|
Error: failError,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func contentHeader(r response) string {
|
func contentHeader(r Response) string {
|
||||||
switch r.Operation {
|
switch r.Operation {
|
||||||
default:
|
default:
|
||||||
return "text/plain"
|
return "text/plain"
|
||||||
|
|
|
@ -27,6 +27,29 @@ type Authority struct {
|
||||||
signAuth SignAuthority
|
signAuth SignAuthority
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type authorityKey struct{}
|
||||||
|
|
||||||
|
// NewContext adds the given authority to the context.
|
||||||
|
func NewContext(ctx context.Context, a *Authority) context.Context {
|
||||||
|
return context.WithValue(ctx, authorityKey{}, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext returns the current authority from the given context.
|
||||||
|
func FromContext(ctx context.Context) (a *Authority, ok bool) {
|
||||||
|
a, ok = ctx.Value(authorityKey{}).(*Authority)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFromContext returns the current authority from the given context. It will
|
||||||
|
// panic if the authority is not in the context.
|
||||||
|
func MustFromContext(ctx context.Context) *Authority {
|
||||||
|
if a, ok := FromContext(ctx); !ok {
|
||||||
|
panic("scep authority is not in the context")
|
||||||
|
} else {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// AuthorityOptions required to create a new SCEP Authority.
|
// AuthorityOptions required to create a new SCEP Authority.
|
||||||
type AuthorityOptions struct {
|
type AuthorityOptions struct {
|
||||||
// Service provides the certificate chain, the signer and the decrypter to the Authority
|
// Service provides the certificate chain, the signer and the decrypter to the Authority
|
||||||
|
@ -163,7 +186,6 @@ func (a *Authority) GetCACertificates(ctx context.Context) ([]*x509.Certificate,
|
||||||
|
|
||||||
// DecryptPKIEnvelope decrypts an enveloped message
|
// DecryptPKIEnvelope decrypts an enveloped message
|
||||||
func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) error {
|
func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) error {
|
||||||
|
|
||||||
p7c, err := pkcs7.Parse(msg.P7.Content)
|
p7c, err := pkcs7.Parse(msg.P7.Content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error parsing pkcs7 content: %w", err)
|
return fmt.Errorf("error parsing pkcs7 content: %w", err)
|
||||||
|
@ -210,7 +232,6 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err
|
||||||
// SignCSR creates an x509.Certificate based on a CSR template and Cert Authority credentials
|
// SignCSR creates an x509.Certificate based on a CSR template and Cert Authority credentials
|
||||||
// returns a new PKIMessage with CertRep data
|
// returns a new PKIMessage with CertRep data
|
||||||
func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage) (*PKIMessage, error) {
|
func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage) (*PKIMessage, error) {
|
||||||
|
|
||||||
// TODO: intermediate storage of the request? In SCEP it's possible to request a csr/certificate
|
// TODO: intermediate storage of the request? In SCEP it's possible to request a csr/certificate
|
||||||
// to be signed, which can be performed asynchronously / out-of-band. In that case a client can
|
// to be signed, which can be performed asynchronously / out-of-band. In that case a client can
|
||||||
// poll for the status. It seems to be similar as what can happen in ACME, so might want to model
|
// poll for the status. It seems to be similar as what can happen in ACME, so might want to model
|
||||||
|
@ -432,7 +453,6 @@ func (a *Authority) CreateFailureResponse(ctx context.Context, csr *x509.Certifi
|
||||||
|
|
||||||
// MatchChallengePassword verifies a SCEP challenge password
|
// MatchChallengePassword verifies a SCEP challenge password
|
||||||
func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) {
|
func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) {
|
||||||
|
|
||||||
p, err := provisionerFromContext(ctx)
|
p, err := provisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|
Loading…
Reference in a new issue