diff --git a/acme/api/order.go b/acme/api/order.go index 3d22ec0f..e1adebb3 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -13,6 +13,7 @@ import ( "github.com/go-chi/chi" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/randutil" ) @@ -107,7 +108,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { for _, identifier := range nor.Identifiers { // TODO: gather all errors, so that we can build subproblems; include the nor.Validate() error here too, like in example? - err = prov.AuthorizeOrderIdentifier(ctx, identifier.Value) + orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value} + err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier) if err != nil { api.WriteError(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) return diff --git a/acme/common.go b/acme/common.go index 4b086dd7..9c5e732a 100644 --- a/acme/common.go +++ b/acme/common.go @@ -30,7 +30,7 @@ var clock Clock // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { - AuthorizeOrderIdentifier(ctx context.Context, identifier string) error + AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeRevoke(ctx context.Context, token string) error GetID() string @@ -45,7 +45,7 @@ type MockProvisioner struct { Merr error MgetID func() string MgetName func() string - MauthorizeOrderIdentifier func(ctx context.Context, identifier string) error + MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) MauthorizeRevoke func(ctx context.Context, token string) error MdefaultTLSCertDuration func() time.Duration @@ -61,7 +61,7 @@ func (m *MockProvisioner) GetName() string { } // AuthorizeOrderIdentifiers mock -func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error { +func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error { if m.MauthorizeOrderIdentifier != nil { return m.MauthorizeOrderIdentifier(ctx, identifier) } diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 2d5f74ff..9f8ef690 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -90,8 +90,6 @@ func (p *ACME) Init(config Config) (err error) { } // Initialize the x509 allow/deny policy engine - // TODO(hs): ensure no race conditions happen when reloading settings and requesting certs? - // TODO(hs): implement memoization strategy, so that reloading is not required when no changes were made to allow/deny? if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { return err } @@ -115,20 +113,22 @@ type ACMEIdentifier struct { Value string } -// AuthorizeOrderIdentifiers verifies the provisioner is authorized to issue a -// certificate for the Identifiers provided in an Order. -func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error { +// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a +// certificate for an ACME Order Identifier. +func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error { + // identifier is allowed if no policy is configured if p.x509Policy == nil { return nil } // assuming only valid identifiers (IP or DNS) are provided var err error - if ip := net.ParseIP(identifier); ip != nil { - _, err = p.x509Policy.IsIPAllowed(ip) - } else { - _, err = p.x509Policy.IsDNSAllowed(identifier) + switch identifier.Type { + case IP: + _, err = p.x509Policy.IsIPAllowed(net.ParseIP(identifier.Value)) + case DNS: + _, err = p.x509Policy.IsDNSAllowed(identifier.Value) } return err