diff --git a/acme/account.go b/acme/account.go index 027d7be1..2dd412db 100644 --- a/acme/account.go +++ b/acme/account.go @@ -7,6 +7,8 @@ import ( "time" "go.step.sm/crypto/jose" + + "github.com/smallstep/certificates/authority/policy" ) // Account is a subset of the internal account type containing only those @@ -43,15 +45,63 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) { return base64.RawURLEncoding.EncodeToString(kid), nil } +// PolicyNames contains ACME account level policy names +type PolicyNames struct { + DNSNames []string `json:"dns"` + IPRanges []string `json:"ips"` +} + +// X509Policy contains ACME account level X.509 policy +type X509Policy struct { + Allowed PolicyNames `json:"allow"` + Denied PolicyNames `json:"deny"` + AllowWildcardNames bool `json:"allowWildcardNames"` +} + +// Policy is an ACME Account level policy +type Policy struct { + X509 X509Policy `json:"x509"` +} + +func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions { + if p == nil { + return nil + } + return &policy.X509NameOptions{ + DNSDomains: p.X509.Allowed.DNSNames, + IPRanges: p.X509.Allowed.IPRanges, + } +} +func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions { + if p == nil { + return nil + } + return &policy.X509NameOptions{ + DNSDomains: p.X509.Denied.DNSNames, + IPRanges: p.X509.Denied.IPRanges, + } +} + +// AreWildcardNamesAllowed returns if wildcard names +// like *.example.com are allowed to be signed. +// Defaults to false. +func (p *Policy) AreWildcardNamesAllowed() bool { + if p == nil { + return false + } + return p.X509.AllowWildcardNames +} + // ExternalAccountKey is an ACME External Account Binding key. type ExternalAccountKey struct { ID string `json:"id"` ProvisionerID string `json:"provisionerID"` Reference string `json:"reference"` AccountID string `json:"-"` - KeyBytes []byte `json:"-"` + HmacKey []byte `json:"-"` CreatedAt time.Time `json:"createdAt"` BoundAt time.Time `json:"boundAt,omitempty"` + Policy *Policy `json:"policy,omitempty"` } // AlreadyBound returns whether this EAK is already bound to @@ -68,6 +118,6 @@ func (eak *ExternalAccountKey) BindTo(account *Account) error { } eak.AccountID = account.ID eak.BoundAt = time.Now() - eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once + eak.HmacKey = []byte{} // clearing the key bytes; can only be used once return nil } diff --git a/acme/account_test.go b/acme/account_test.go index 33524d87..edd1f5b0 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -7,8 +7,9 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/assert" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" ) func TestKeyToID(t *testing.T) { @@ -95,7 +96,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) { ID: "eakID", ProvisionerID: "provID", Reference: "ref", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, }, acct: &Account{ ID: "accountID", @@ -108,7 +109,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) { ID: "eakID", ProvisionerID: "provID", Reference: "ref", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, AccountID: "someAccountID", BoundAt: boundAt, }, @@ -138,7 +139,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) { assert.Equals(t, ae.Subproblems, tt.err.Subproblems) } else { assert.Equals(t, eak.AccountID, acct.ID) - assert.Equals(t, eak.KeyBytes, []byte{}) + assert.Equals(t, eak.HmacKey, []byte{}) assert.NotNil(t, eak.BoundAt) } }) diff --git a/acme/api/account.go b/acme/api/account.go index ade51aef..f6e18f90 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -131,8 +131,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response - err := eak.BindTo(acc) - if err != nil { + if err := eak.BindTo(acc); err != nil { render.Error(w, err) return } diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 4c3404ec..a0161cb4 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -13,10 +13,12 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + + "go.step.sm/crypto/jose" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" ) var ( @@ -41,6 +43,19 @@ func newProv() acme.Provisioner { return p } +func newProvWithOptions(options *provisioner.Options) acme.Provisioner { + // Initialize provisioners + p := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + Options: options, + } + if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { + fmt.Printf("%v", err) + } + return p +} + func newACMEProv(t *testing.T) *provisioner.ACME { p := newProv() a, ok := p.(*provisioner.ACME) @@ -50,6 +65,15 @@ func newACMEProv(t *testing.T) *provisioner.ACME { return a } +func newACMEProvWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME { + p := newProvWithOptions(options) + a, ok := p.(*provisioner.ACME) + if !ok { + t.Fatal("not a valid ACME provisioner") + } + return a +} + func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) { signer, err := jose.NewSigner( jose.SigningKey{ @@ -558,7 +582,7 @@ func TestHandler_NewAccount(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), } return test{ @@ -735,7 +759,7 @@ func TestHandler_NewAccount(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, diff --git a/acme/api/eab.go b/acme/api/eab.go index 3660d066..84be6453 100644 --- a/acme/api/eab.go +++ b/acme/api/eab.go @@ -4,8 +4,9 @@ import ( "context" "encoding/json" - "github.com/smallstep/certificates/acme" "go.step.sm/crypto/jose" + + "github.com/smallstep/certificates/acme" ) // ExternalAccountBinding represents the ACME externalAccountBinding JWS @@ -55,11 +56,19 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc return nil, acme.WrapErrorISE(err, "error retrieving external account key") } + if externalAccountKey == nil { + return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key") + } + + if len(externalAccountKey.HmacKey) == 0 { + return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID) + } + if externalAccountKey.AlreadyBound() { return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt) } - payload, err := eabJWS.Verify(externalAccountKey.KeyBytes) + payload, err := eabJWS.Verify(externalAccountKey.HmacKey) if err != nil { return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature") } diff --git a/acme/api/eab_test.go b/acme/api/eab_test.go index dce9f36d..c2725588 100644 --- a/acme/api/eab_test.go +++ b/acme/api/eab_test.go @@ -9,10 +9,12 @@ import ( "time" "github.com/pkg/errors" + + "go.step.sm/crypto/jose" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" ) func Test_keysAreEqual(t *testing.T) { @@ -154,7 +156,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: createdAt, }, nil }, @@ -168,7 +170,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: createdAt, }, err: nil, @@ -426,6 +428,114 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { err: acme.NewErrorISE("error retrieving external account key"), } }, + "fail/db.GetExternalAccountKey-nil": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return nil, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key"), + } + }, + "fail/db.GetExternalAccountKey-no-keybytes": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + createdAt := time.Now() + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + CreatedAt: createdAt, + AccountID: "some-account-id", + HmacKey: []byte{}, + }, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorServerInternalType, "external account binding key with id 'eakID' does not have secret bytes"), + } + }, "fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) @@ -520,6 +630,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { Reference: "testeak", CreatedAt: createdAt, AccountID: "some-account-id", + HmacKey: []byte{1, 3, 3, 7}, BoundAt: boundAt, }, nil }, @@ -575,7 +686,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 2, 3, 4}, + HmacKey: []byte{1, 2, 3, 4}, CreatedAt: time.Now(), }, nil }, @@ -633,7 +744,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, @@ -688,7 +799,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, @@ -744,7 +855,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ID: "eakID", ProvisionerID: provID, Reference: "testeak", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, @@ -787,7 +898,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { } else { assert.NotNil(t, tc.eak) assert.Equals(t, got.ID, tc.eak.ID) - assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, got.HmacKey, tc.eak.HmacKey) assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, got.Reference, tc.eak.Reference) assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt) diff --git a/acme/api/order.go b/acme/api/order.go index 99eb0e95..c37285d2 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -16,6 +16,8 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" ) // NewOrderRequest represents the body for a NewOrder request. @@ -37,6 +39,8 @@ func (n *NewOrderRequest) Validate() error { if id.Type == acme.IP && net.ParseIP(id.Value) == nil { return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value) } + // TODO(hs): add some validations for DNS domains? + // TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1 } return nil } @@ -85,6 +89,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } + var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, @@ -97,6 +102,48 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } + // TODO(hs): gather all errors, so that we can build one response with ACME subproblems + // include the nor.Validate() error here too, like in the example in the ACME RFC? + + acmeProv, err := acmeProvisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } + + var eak *acme.ExternalAccountKey + if acmeProv.RequireEAB { + if eak, err = h.db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil { + render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key")) + return + } + } + + acmePolicy, err := newACMEPolicyEngine(eak) + if err != nil { + render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine")) + return + } + + for _, identifier := range nor.Identifiers { + // evaluate the ACME account level policy + if err = isIdentifierAllowed(acmePolicy, identifier); err != nil { + render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + return + } + // evaluate the provisioner level policy + orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value} + if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil { + render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + return + } + // evaluate the authority level policy + if err = h.ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil { + render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + return + } + } + now := clock.Now() // New order. o := &acme.Order{ @@ -147,6 +194,20 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, o, http.StatusCreated) } +func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error { + if acmePolicy == nil { + return nil + } + return acmePolicy.AreSANsAllowed([]string{identifier.Value}) +} + +func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) { + if eak == nil { + return nil, nil + } + return policy.NewX509PolicyEngine(eak.Policy) +} + func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { if strings.HasPrefix(az.Identifier.Value, "*.") { az.Wildcard = true diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 1ce034e7..35abab65 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -16,9 +16,13 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + + "go.step.sm/crypto/pemutil" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "go.step.sm/crypto/pemutil" + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" ) func TestNewOrderRequest_Validate(t *testing.T) { @@ -667,6 +671,7 @@ func TestHandler_NewOrder(t *testing.T) { baseURL.String(), escProvName) type test struct { + ca acme.CertificateAuthority db acme.DB ctx context.Context nor *NewOrderRequest @@ -756,6 +761,222 @@ func TestHandler_NewOrder(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, + "fail/acmeProvisionerFromContext-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, &acme.MockProvisioner{}) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 500, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, errors.New("force") + }, + }, + err: acme.NewErrorISE("error retrieving external account binding key: force"), + } + }, + "fail/db.GetExternalAccountKeyByAccountID-error": func(t *testing.T) test { + acmeProv := newACMEProv(t) + acmeProv.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 500, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, errors.New("force") + }, + }, + err: acme.NewErrorISE("error retrieving external account binding key: force"), + } + }, + "fail/newACMEPolicyEngine-error": func(t *testing.T) test { + acmeProv := newACMEProv(t) + acmeProv.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 500, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return &acme.ExternalAccountKey{ + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"**.local"}, + }, + }, + }, + }, nil + }, + }, + err: acme.NewErrorISE("error creating ACME policy engine"), + } + }, + "fail/isIdentifierAllowed-error": func(t *testing.T) test { + acmeProv := newACMEProv(t) + acmeProv.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return &acme.ExternalAccountKey{ + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.local"}, + }, + }, + }, + }, nil + }, + }, + err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), + } + }, + "fail/prov.AuthorizeOrderIdentifier-error": func(t *testing.T) test { + options := &provisioner.Options{ + X509: &provisioner.X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + } + provWithPolicy := newACMEProvWithOptions(t, options) + provWithPolicy.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + ca: &mockCA{}, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return &acme.ExternalAccountKey{ + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.internal"}, + }, + }, + }, + }, nil + }, + }, + err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), + } + }, + "fail/ca.AreSANsAllowed-error": func(t *testing.T) test { + options := &provisioner.Options{ + X509: &provisioner.X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.internal"}, + }, + }, + } + provWithPolicy := newACMEProvWithOptions(t, options) + provWithPolicy.RequireEAB = true + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + return test{ + ctx: ctx, + statusCode: 400, + ca: &mockCA{ + MockAreSANsallowed: func(ctx context.Context, sans []string) error { + return errors.New("force: not authorized by authority") + }, + }, + db: &acme.MockDB{ + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return &acme.ExternalAccountKey{ + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.internal"}, + }, + }, + }, + }, nil + }, + }, + err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), + } + }, "fail/error-h.newAuthorization": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ @@ -771,6 +992,7 @@ func TestHandler_NewOrder(t *testing.T) { return test{ ctx: ctx, statusCode: 500, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.AccountID, "accID") @@ -780,6 +1002,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, ch.Value, "zap.internal") return errors.New("force") }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, err: acme.NewErrorISE("error creating challenge: force"), } @@ -804,6 +1031,7 @@ func TestHandler_NewOrder(t *testing.T) { return test{ ctx: ctx, statusCode: 500, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -849,6 +1077,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return errors.New("force") }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, err: acme.NewErrorISE("error creating order: force"), } @@ -876,6 +1109,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch chCount { @@ -945,6 +1179,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() @@ -991,6 +1230,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -1037,6 +1277,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() @@ -1083,6 +1328,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -1129,6 +1375,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() @@ -1174,6 +1425,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -1220,6 +1472,11 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { testBufferDur := 5 * time.Second @@ -1266,6 +1523,7 @@ func TestHandler_NewOrder(t *testing.T) { ctx: ctx, statusCode: 201, nor: nor, + ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { @@ -1312,11 +1570,120 @@ func TestHandler_NewOrder(t *testing.T) { assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, }, vr: func(t *testing.T, o *acme.Order) { testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/default-naf-nbf-with-policy": func(t *testing.T) test { + options := &provisioner.Options{ + X509: &provisioner.X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.internal"}, + }, + }, + } + provWithPolicy := newACMEProvWithOptions(t, options) + provWithPolicy.RequireEAB = true + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + ca: &mockCA{}, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, acme.DNS01) + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, acme.HTTP01) + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, acme.TLSALPN01) + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, prov.GetID(), provisionerID) + assert.Equals(t, "accID", accountID) + return nil, nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + now := clock.Now() + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNbf := now.Add(-defaultOrderBackdate) + expNaf := now.Add(prov.DefaultTLSCertDuration()) + assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) @@ -1334,7 +1701,7 @@ func TestHandler_NewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 4ff54405..9b1fd6d5 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -24,14 +24,16 @@ import ( "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" + "golang.org/x/crypto/ocsp" + + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/x509util" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ocsp" ) // v is a utility function to return the pointer to an integer @@ -274,14 +276,22 @@ func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error } type mockCA struct { - MockIsRevoked func(sn string) (bool, error) - MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error + MockIsRevoked func(sn string) (bool, error) + MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error + MockAreSANsallowed func(ctx context.Context, sans []string) error } func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return nil, nil } +func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error { + if m.MockAreSANsallowed != nil { + return m.MockAreSANsallowed(ctx, sans) + } + return nil +} + func (m *mockCA) IsRevoked(sn string) (bool, error) { if m.MockIsRevoked != nil { return m.MockIsRevoked(sn) diff --git a/acme/common.go b/acme/common.go index 0c9e83dc..e0d96deb 100644 --- a/acme/common.go +++ b/acme/common.go @@ -12,6 +12,7 @@ import ( // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) Revoke(context.Context, *authority.RevokeOptions) error LoadProvisionerByName(string) (provisioner.Interface, error) @@ -30,6 +31,7 @@ var clock Clock // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { + AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeRevoke(ctx context.Context, token string) error GetID() string @@ -40,14 +42,15 @@ type Provisioner interface { // MockProvisioner for testing type MockProvisioner struct { - Mret1 interface{} - Merr error - MgetID func() string - MgetName func() string - MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) - MauthorizeRevoke func(ctx context.Context, token string) error - MdefaultTLSCertDuration func() time.Duration - MgetOptions func() *provisioner.Options + Mret1 interface{} + Merr error + MgetID func() string + MgetName func() string + MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error + MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) + MauthorizeRevoke func(ctx context.Context, token string) error + MdefaultTLSCertDuration func() time.Duration + MgetOptions func() *provisioner.Options } // GetName mock @@ -58,6 +61,14 @@ func (m *MockProvisioner) GetName() string { return m.Mret1.(string) } +// AuthorizeOrderIdentifiers mock +func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error { + if m.MauthorizeOrderIdentifier != nil { + return m.MauthorizeOrderIdentifier(ctx, identifier) + } + return m.Merr +} + // AuthorizeSign mock func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { if m.MauthorizeSign != nil { diff --git a/acme/db.go b/acme/db.go index 412276fd..b53cb397 100644 --- a/acme/db.go +++ b/acme/db.go @@ -23,6 +23,7 @@ type DB interface { GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) + GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error @@ -60,6 +61,7 @@ type MockDB struct { MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) + MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error @@ -168,6 +170,16 @@ func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provision return m.MockRet1.(*ExternalAccountKey), m.MockError } +// GetExternalAccountKeyByAccountID mock +func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) { + if m.MockGetExternalAccountKeyByAccountID != nil { + return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*ExternalAccountKey), m.MockError +} + // DeleteExternalAccountKey mock func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error { if m.MockDeleteExternalAccountKey != nil { diff --git a/acme/db/nosql/eab.go b/acme/db/nosql/eab.go index f9a24daf..e87aa9bc 100644 --- a/acme/db/nosql/eab.go +++ b/acme/db/nosql/eab.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" nosqlDB "github.com/smallstep/nosql" ) @@ -23,7 +24,7 @@ type dbExternalAccountKey struct { ProvisionerID string `json:"provisionerID"` Reference string `json:"reference"` AccountID string `json:"accountID,omitempty"` - KeyBytes []byte `json:"key"` + HmacKey []byte `json:"key"` CreatedAt time.Time `json:"createdAt"` BoundAt time.Time `json:"boundAt"` } @@ -72,7 +73,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer ID: keyID, ProvisionerID: provisionerID, Reference: reference, - KeyBytes: random, + HmacKey: random, CreatedAt: clock.Now(), } @@ -99,7 +100,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer ProvisionerID: dbeak.ProvisionerID, Reference: dbeak.Reference, AccountID: dbeak.AccountID, - KeyBytes: dbeak.KeyBytes, + HmacKey: dbeak.HmacKey, CreatedAt: dbeak.CreatedAt, BoundAt: dbeak.BoundAt, }, nil @@ -124,7 +125,7 @@ func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID st ProvisionerID: dbeak.ProvisionerID, Reference: dbeak.Reference, AccountID: dbeak.AccountID, - KeyBytes: dbeak.KeyBytes, + HmacKey: dbeak.HmacKey, CreatedAt: dbeak.CreatedAt, BoundAt: dbeak.BoundAt, }, nil @@ -191,7 +192,7 @@ func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor } keys = append(keys, &acme.ExternalAccountKey{ ID: eak.ID, - KeyBytes: eak.KeyBytes, + HmacKey: eak.HmacKey, ProvisionerID: eak.ProvisionerID, Reference: eak.Reference, AccountID: eak.AccountID, @@ -226,6 +227,10 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID) } +func (db *DB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { + return nil, nil +} + func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { externalAccountKeyMutex.Lock() defer externalAccountKeyMutex.Unlock() @@ -252,7 +257,7 @@ func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string ProvisionerID: eak.ProvisionerID, Reference: eak.Reference, AccountID: eak.AccountID, - KeyBytes: eak.KeyBytes, + HmacKey: eak.HmacKey, CreatedAt: eak.CreatedAt, BoundAt: eak.BoundAt, } diff --git a/acme/db/nosql/eab_test.go b/acme/db/nosql/eab_test.go index 568500e9..525afa72 100644 --- a/acme/db/nosql/eab_test.go +++ b/acme/db/nosql/eab_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/pkg/errors" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" certdb "github.com/smallstep/certificates/db" @@ -32,7 +33,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -108,7 +109,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) { } } else if assert.Nil(t, tc.err) { assert.Equals(t, dbeak.ID, tc.dbeak.ID) - assert.Equals(t, dbeak.KeyBytes, tc.dbeak.KeyBytes) + assert.Equals(t, dbeak.HmacKey, tc.dbeak.HmacKey) assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID) assert.Equals(t, dbeak.Reference, tc.dbeak.Reference) assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt) @@ -136,7 +137,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -154,7 +155,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, } @@ -179,7 +180,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { ProvisionerID: "aDifferentProvID", Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -197,7 +198,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: "ref", AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"), @@ -225,7 +226,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) { } } else if assert.Nil(t, tc.err) { assert.Equals(t, eak.ID, tc.eak.ID) - assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, eak.HmacKey, tc.eak.HmacKey) assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, eak.Reference, tc.eak.Reference) assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt) @@ -255,7 +256,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -288,7 +289,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, err: nil, @@ -392,7 +393,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) { assert.Equals(t, eak.AccountID, tc.eak.AccountID) assert.Equals(t, eak.BoundAt, tc.eak.BoundAt) assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt) - assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, eak.HmacKey, tc.eak.HmacKey) assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, eak.Reference, tc.eak.Reference) } @@ -420,7 +421,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b1, err := json.Marshal(dbeak1) @@ -430,7 +431,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b2, err := json.Marshal(dbeak2) @@ -440,7 +441,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: "aDifferentProvID", Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b3, err := json.Marshal(dbeak3) @@ -513,7 +514,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, { @@ -521,7 +522,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, }, }, @@ -598,7 +599,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) { assert.Equals(t, "", nextCursor) for i, eak := range eaks { assert.Equals(t, eak.ID, tc.eaks[i].ID) - assert.Equals(t, eak.KeyBytes, tc.eaks[i].KeyBytes) + assert.Equals(t, eak.HmacKey, tc.eaks[i].HmacKey) assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID) assert.Equals(t, eak.Reference, tc.eaks[i].Reference) assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt) @@ -627,7 +628,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -707,7 +708,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: "aDifferentProvID", Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -730,7 +731,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -780,7 +781,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -830,7 +831,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } dbref := &dbExternalAccountKeyReference{ @@ -953,7 +954,7 @@ func TestDB_CreateExternalAccountKey(t *testing.T) { assert.Equals(t, string(key), dbeak.ID) assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID) assert.Equals(t, eak.Reference, dbeak.Reference) - assert.Equals(t, 32, len(dbeak.KeyBytes)) + assert.Equals(t, 32, len(dbeak.HmacKey)) assert.False(t, dbeak.CreatedAt.IsZero()) assert.Equals(t, dbeak.AccountID, eak.AccountID) assert.True(t, dbeak.BoundAt.IsZero()) @@ -1078,7 +1079,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) @@ -1096,7 +1097,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } return test{ @@ -1120,7 +1121,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { assert.Equals(t, dbNew.AccountID, dbeak.AccountID) assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt) assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt) - assert.Equals(t, dbNew.KeyBytes, dbeak.KeyBytes) + assert.Equals(t, dbNew.HmacKey, dbeak.HmacKey) return nu, true, nil }, }, @@ -1148,7 +1149,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: "aDifferentProvID", Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) @@ -1174,7 +1175,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) @@ -1200,7 +1201,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { ProvisionerID: provID, Reference: ref, AccountID: "", - KeyBytes: []byte{1, 3, 3, 7}, + HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) @@ -1237,7 +1238,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) { assert.Equals(t, dbeak.AccountID, tc.eak.AccountID) assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt) assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt) - assert.Equals(t, dbeak.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, dbeak.HmacKey, tc.eak.HmacKey) } }) } diff --git a/acme/order_test.go b/acme/order_test.go index 493b40b7..f1f28e40 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -268,6 +268,7 @@ func TestOrder_UpdateStatus(t *testing.T) { type mockSignAuth struct { sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) ret1, ret2 interface{} err error @@ -282,6 +283,13 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error { + if m.areSANsAllowed != nil { + return m.areSANsAllowed(ctx, sans) + } + return m.err +} + func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) { if m.loadProvisionerByName != nil { return m.loadProvisionerByName(name) diff --git a/api/read/read.go b/api/read/read.go index de92c5d7..72530b8c 100644 --- a/api/read/read.go +++ b/api/read/read.go @@ -3,16 +3,20 @@ package read import ( "encoding/json" + "errors" "io" + "net/http" + "strings" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) // JSON reads JSON from the request body and stores it in the value -// pointed by v. +// pointed to by v. func JSON(r io.Reader, v interface{}) error { if err := json.NewDecoder(r).Decode(v); err != nil { return errs.BadRequestErr(err, "error decoding json") @@ -21,11 +25,42 @@ func JSON(r io.Reader, v interface{}) error { } // ProtoJSON reads JSON from the request body and stores it in the value -// pointed by v. +// pointed to by m. func ProtoJSON(r io.Reader, m proto.Message) error { data, err := io.ReadAll(r) if err != nil { return errs.BadRequestErr(err, "error reading request body") } - return protojson.Unmarshal(data, m) + + switch err := protojson.Unmarshal(data, m); { + case errors.Is(err, proto.Error): + return badProtoJSONError(err.Error()) + default: + return err + } +} + +// badProtoJSONError is an error type that is returned by ProtoJSON +// when a proto message cannot be unmarshaled. Usually this is caused +// by an error in the request body. +type badProtoJSONError string + +// Error implements error for badProtoJSONError +func (e badProtoJSONError) Error() string { + return string(e) +} + +// Render implements render.RenderableError for badProtoJSONError +func (e badProtoJSONError) Render(w http.ResponseWriter) { + v := struct { + Type string `json:"type"` + Detail string `json:"detail"` + Message string `json:"message"` + }{ + Type: "badRequest", + Detail: "bad request", + // trim the proto prefix for the message + Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")), + } + render.JSONStatus(w, v, http.StatusBadRequest) } diff --git a/api/read/read_test.go b/api/read/read_test.go index f2eff1bc..72100584 100644 --- a/api/read/read_test.go +++ b/api/read/read_test.go @@ -1,10 +1,21 @@ package read import ( + "encoding/json" + "errors" "io" + "net/http" + "net/http/httptest" "reflect" "strings" "testing" + "testing/iotest" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "go.step.sm/linkedca" "github.com/smallstep/certificates/errs" ) @@ -44,3 +55,110 @@ func TestJSON(t *testing.T) { }) } } + +func TestProtoJSON(t *testing.T) { + + p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import? + + type args struct { + r io.Reader + m proto.Message + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "fail/io.ReadAll", + args: args{ + r: iotest.ErrReader(errors.New("read error")), + m: p, + }, + wantErr: true, + }, + { + name: "fail/proto", + args: args{ + r: strings.NewReader(`{?}`), + m: p, + }, + wantErr: true, + }, + { + name: "ok", + args: args{ + r: strings.NewReader(`{"x509":{}}`), + m: p, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ProtoJSON(tt.args.r, tt.args.m) + if (err != nil) != tt.wantErr { + t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + switch err.(type) { + case badProtoJSONError: + assert.Contains(t, err.Error(), "syntax error") + case *errs.Error: + var ee *errs.Error + if errors.As(err, &ee) { + assert.Equal(t, http.StatusBadRequest, ee.Status) + } + } + return + } + + assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m)) + assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m)) + }) + } +} + +func Test_badProtoJSONError_Render(t *testing.T) { + tests := []struct { + name string + e badProtoJSONError + expected string + }{ + { + name: "bad proto normal space", + e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"), + expected: "syntax error (line 1:2): invalid value ?", + }, + { + name: "bad proto non breaking space", + e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"), + expected: "syntax error (line 1:2): invalid value ?", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + w := httptest.NewRecorder() + tt.e.Render(w) + res := w.Result() + defer res.Body.Close() + + data, err := io.ReadAll(res.Body) + assert.NoError(t, err) + + v := struct { + Type string `json:"type"` + Detail string `json:"detail"` + Message string `json:"message"` + }{} + + assert.NoError(t, json.Unmarshal(data, &v)) + assert.Equal(t, "badRequest", v.Type) + assert.Equal(t, "bad request", v.Detail) + assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message) + + }) + } +} diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 21a7229d..da491dfe 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -1,22 +1,15 @@ package api import ( - "context" "fmt" "net/http" - "github.com/go-chi/chi" - "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/authority/provisioner" -) - -const ( - // provisionerContextKey provisioner key - provisionerContextKey = ContextKey("provisioner") ) // CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests @@ -40,51 +33,24 @@ type GetExternalAccountKeysResponse struct { // requireEABEnabled is a middleware that ensures ACME EAB is enabled // before serving requests that act on ACME EAB credentials. -func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { +func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - provName := chi.URLParam(r, "provisionerName") - eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) - if err != nil { - render.Error(w, err) + prov := linkedca.MustProvisionerFromContext(ctx) + + acmeProvisioner := prov.GetDetails().GetACME() + if acmeProvisioner == nil { + render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName())) return } - if !eabEnabled { - render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) + + if !acmeProvisioner.RequireEab { + render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName())) return } - ctx = context.WithValue(ctx, provisionerContextKey, prov) - next(w, r.WithContext(ctx)) - } -} -// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME -// provisioner is set to true and thus has EAB enabled. -func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) { - var ( - p provisioner.Interface - err error - ) - if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil { - return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName) + next(w, r) } - - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) - if err != nil { - return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID()) - } - - details := prov.GetDetails() - if details == nil { - return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID()) - } - - acmeProvisioner := details.GetACME() - if acmeProvisioner == nil { - return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID()) - } - - return acmeProvisioner.GetRequireEab(), prov, nil } type acmeAdminResponderInterface interface { @@ -115,3 +81,72 @@ func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r * func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } + +func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey { + + if k == nil { + return nil + } + + eak := &linkedca.EABKey{ + Id: k.ID, + HmacKey: k.HmacKey, + Provisioner: k.ProvisionerID, + Reference: k.Reference, + Account: k.AccountID, + CreatedAt: timestamppb.New(k.CreatedAt), + BoundAt: timestamppb.New(k.BoundAt), + } + + if k.Policy != nil { + eak.Policy = &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{}, + Deny: &linkedca.X509Names{}, + }, + } + eak.Policy.X509.Allow.Dns = k.Policy.X509.Allowed.DNSNames + eak.Policy.X509.Allow.Ips = k.Policy.X509.Allowed.IPRanges + eak.Policy.X509.Deny.Dns = k.Policy.X509.Denied.DNSNames + eak.Policy.X509.Deny.Ips = k.Policy.X509.Denied.IPRanges + eak.Policy.X509.AllowWildcardNames = k.Policy.X509.AllowWildcardNames + } + + return eak +} + +func linkedEAKToCertificates(k *linkedca.EABKey) *acme.ExternalAccountKey { + if k == nil { + return nil + } + + eak := &acme.ExternalAccountKey{ + ID: k.Id, + ProvisionerID: k.Provisioner, + Reference: k.Reference, + AccountID: k.Account, + HmacKey: k.HmacKey, + CreatedAt: k.CreatedAt.AsTime(), + BoundAt: k.BoundAt.AsTime(), + } + + if policy := k.GetPolicy(); policy != nil { + eak.Policy = &acme.Policy{} + if x509 := policy.GetX509(); x509 != nil { + eak.Policy.X509 = acme.X509Policy{} + if allow := x509.GetAllow(); allow != nil { + eak.Policy.X509.Allowed = acme.PolicyNames{} + eak.Policy.X509.Allowed.DNSNames = allow.Dns + eak.Policy.X509.Allowed.IPRanges = allow.Ips + } + if deny := x509.GetDeny(); deny != nil { + eak.Policy.X509.Denied = acme.PolicyNames{} + eak.Policy.X509.Denied.DNSNames = deny.Dns + eak.Policy.X509.Denied.IPRanges = deny.Ips + } + eak.Policy.X509.AllowWildcardNames = x509.AllowWildcardNames + } + } + + return eak +} diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go index 6ffe1418..3ff32763 100644 --- a/authority/admin/api/acme_test.go +++ b/authority/admin/api/acme_test.go @@ -4,20 +4,24 @@ import ( "bytes" "context" "encoding/json" - "errors" "io" "net/http" "net/http/httptest" + "reflect" "strings" "testing" + "time" "github.com/go-chi/chi" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/admin" ) func readProtoJSON(r io.ReadCloser, m proto.Message) error { @@ -32,106 +36,76 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error { func TestHandler_requireEABEnabled(t *testing.T) { type test struct { ctx context.Context - adminDB admin.DB - auth adminAuthority - next nextHTTP + next http.HandlerFunc err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ - "fail/h.provisionerHasEABEnabled": func(t *testing.T) test { - chiCtx := chi.NewRouteContext() - chiCtx.URLParams.Add("provisionerName", "provName") - ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return nil, errors.New("force") - }, + "fail/prov.GetDetails": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", } - err := admin.NewErrorISE("error loading provisioner provName: force") - err.Message = "error loading provisioner provName: force" + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'") + err.Message = "error getting ACME details for provisioner 'provName'" + return test{ + ctx: ctx, + err: err, + statusCode: 500, + } + }, + "fail/prov.GetDetails.GetACME": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{}, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'") + err.Message = "error getting ACME details for provisioner 'provName'" return test{ ctx: ctx, - auth: auth, err: err, statusCode: 500, } }, "ok/eab-disabled": func(t *testing.T) test { - chiCtx := chi.NewRouteContext() - chiCtx.URLParams.Add("provisionerName", "provName") - ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: false, }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "provName", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: &linkedca.ACMEProvisioner{ - RequireEab: false, - }, - }, - }, - }, nil + }, }, } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName") - err.Message = "ACME EAB not enabled for provisioner provName" + err.Message = "ACME EAB not enabled for provisioner 'provName'" return test{ ctx: ctx, - auth: auth, - adminDB: db, err: err, statusCode: 400, } }, "ok/eab-enabled": func(t *testing.T) test { - chiCtx := chi.NewRouteContext() - chiCtx.URLParams.Add("provisionerName", "provName") - ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: true, }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "provName", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: &linkedca.ACMEProvisioner{ - RequireEab: true, - }, - }, - }, - }, nil + }, }, } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ - ctx: ctx, - auth: auth, - adminDB: db, + ctx: ctx, next: func(w http.ResponseWriter, r *http.Request) { w.Write(nil) // mock response with status 200 }, @@ -143,13 +117,9 @@ func TestHandler_requireEABEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - acmeDB: nil, - } + h := &Handler{} - req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup + req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.requireEABEnabled(tc.next)(w, req) @@ -176,216 +146,6 @@ func TestHandler_requireEABEnabled(t *testing.T) { } } -func TestHandler_provisionerHasEABEnabled(t *testing.T) { - type test struct { - adminDB admin.DB - auth adminAuthority - provisionerName string - want bool - err *admin.Error - } - var tests = map[string]func(t *testing.T) test{ - "fail/auth.LoadProvisionerByName": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return nil, errors.New("force") - }, - } - return test{ - auth: auth, - provisionerName: "provName", - want: false, - err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), - } - }, - "fail/db.GetProvisioner": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return nil, errors.New("force") - }, - } - return test{ - auth: auth, - adminDB: db, - provisionerName: "provName", - want: false, - err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), - } - }, - "fail/prov.GetDetails": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "provName", - Details: nil, - }, nil - }, - } - return test{ - auth: auth, - adminDB: db, - provisionerName: "provName", - want: false, - err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), - } - }, - "fail/details.GetACME": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "provName", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "provName", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: nil, - }, - }, - }, nil - }, - } - return test{ - auth: auth, - adminDB: db, - provisionerName: "provName", - want: false, - err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), - } - }, - "ok/eab-disabled": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "eab-disabled", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "eab-disabled", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: &linkedca.ACMEProvisioner{ - RequireEab: false, - }, - }, - }, - }, nil - }, - } - return test{ - adminDB: db, - auth: auth, - provisionerName: "eab-disabled", - want: false, - } - }, - "ok/eab-enabled": func(t *testing.T) test { - auth := &mockAdminAuthority{ - MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { - assert.Equals(t, "eab-enabled", name) - return &provisioner.MockProvisioner{ - MgetID: func() string { - return "provID" - }, - }, nil - }, - } - db := &admin.MockDB{ - MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { - assert.Equals(t, "provID", id) - return &linkedca.Provisioner{ - Id: "provID", - Name: "eab-enabled", - Details: &linkedca.ProvisionerDetails{ - Data: &linkedca.ProvisionerDetails_ACME{ - ACME: &linkedca.ACMEProvisioner{ - RequireEab: true, - }, - }, - }, - }, nil - }, - } - return test{ - adminDB: db, - auth: auth, - provisionerName: "eab-enabled", - want: true, - } - }, - } - for name, prep := range tests { - tc := prep(t) - t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - acmeDB: nil, - } - got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName) - if (err != nil) != (tc.err != nil) { - t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err) - return - } - if tc.err != nil { - assert.Type(t, &linkedca.Provisioner{}, prov) - assert.Type(t, &admin.Error{}, err) - adminError, _ := err.(*admin.Error) - assert.Equals(t, tc.err.Type, adminError.Type) - assert.Equals(t, tc.err.Status, adminError.Status) - assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode()) - assert.Equals(t, tc.err.Message, adminError.Message) - assert.Equals(t, tc.err.Detail, adminError.Detail) - return - } - if got != tc.want { - t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want) - } - }) - } -} - func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) { type fields struct { Reference string @@ -585,3 +345,206 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) { }) } } + +func Test_eakToLinked(t *testing.T) { + tests := []struct { + name string + k *acme.ExternalAccountKey + want *linkedca.EABKey + }{ + { + name: "no-key", + k: nil, + want: nil, + }, + { + name: "no-policy", + k: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: nil, + }, + want: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: nil, + }, + }, + { + name: "with-policy", + k: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/24"}, + }, + Denied: acme.PolicyNames{ + DNSNames: []string{"badhost.local"}, + IPRanges: []string{"10.0.0.30"}, + }, + }, + }, + }, + want: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/24"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"10.0.0.30"}, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) { + t.Errorf("eakToLinked() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_linkedEAKToCertificates(t *testing.T) { + tests := []struct { + name string + k *linkedca.EABKey + want *acme.ExternalAccountKey + }{ + { + name: "no-key", + k: nil, + want: nil, + }, + { + name: "no-policy", + k: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: nil, + }, + want: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: nil, + }, + }, + { + name: "no-x509-policy", + k: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: &linkedca.Policy{}, + }, + want: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: &acme.Policy{}, + }, + }, + { + name: "with-x509-policy", + k: &linkedca.EABKey{ + Id: "keyID", + Provisioner: "provID", + HmacKey: []byte{1, 3, 3, 7}, + Reference: "ref", + Account: "accID", + CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), + BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), + Policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/24"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"10.0.0.30"}, + }, + AllowWildcardNames: true, + }, + }, + }, + want: &acme.ExternalAccountKey{ + ID: "keyID", + ProvisionerID: "provID", + Reference: "ref", + AccountID: "accID", + HmacKey: []byte{1, 3, 3, 7}, + CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), + BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), + Policy: &acme.Policy{ + X509: acme.X509Policy{ + Allowed: acme.PolicyNames{ + DNSNames: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/24"}, + }, + Denied: acme.PolicyNames{ + DNSNames: []string{"badhost.local"}, + IPRanges: []string{"10.0.0.30"}, + }, + AllowWildcardNames: true, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) { + t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 5e4b9c30..a033b1a5 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -29,6 +29,10 @@ type adminAuthority interface { LoadProvisionerByID(id string) (provisioner.Interface, error) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error RemoveProvisioner(ctx context.Context, id string) error + GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) + CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + RemoveAuthorityPolicy(ctx context.Context) error } // CreateAdminRequest represents the body for a CreateAdmin request. diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index 8d223b52..cc77ef77 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -14,11 +14,13 @@ import ( "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" - "google.golang.org/protobuf/types/known/timestamppb" ) type mockAdminAuthority struct { @@ -37,6 +39,11 @@ type mockAdminAuthority struct { MockLoadProvisionerByID func(id string) (provisioner.Interface, error) MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error MockRemoveProvisioner func(ctx context.Context, id string) error + + MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) + MockCreateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + MockUpdateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) + MockRemoveAuthorityPolicy func(ctx context.Context) error } func (m *mockAdminAuthority) IsAdminAPIEnabled() bool { @@ -130,6 +137,34 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e return m.MockErr } +func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + if m.MockGetAuthorityPolicy != nil { + return m.MockGetAuthorityPolicy(ctx) + } + return m.MockRet1.(*linkedca.Policy), m.MockErr +} + +func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + if m.MockCreateAuthorityPolicy != nil { + return m.MockCreateAuthorityPolicy(ctx, adm, policy) + } + return m.MockRet1.(*linkedca.Policy), m.MockErr +} + +func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + if m.MockUpdateAuthorityPolicy != nil { + return m.MockUpdateAuthorityPolicy(ctx, adm, policy) + } + return m.MockRet1.(*linkedca.Policy), m.MockErr +} + +func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error { + if m.MockRemoveAuthorityPolicy != nil { + return m.MockRemoveAuthorityPolicy(ctx) + } + return m.MockErr +} + func TestCreateAdminRequest_Validate(t *testing.T) { type fields struct { Subject string diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index 99e74c88..eb52ad58 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -1,6 +1,8 @@ package api import ( + "net/http" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" @@ -8,30 +10,53 @@ import ( // Handler is the Admin API request handler. type Handler struct { - adminDB admin.DB - auth adminAuthority - acmeDB acme.DB - acmeResponder acmeAdminResponderInterface + adminDB admin.DB + auth adminAuthority + acmeDB acme.DB + acmeResponder acmeAdminResponderInterface + policyResponder policyAdminResponderInterface } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler { +func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler { return &Handler{ - auth: auth, - adminDB: adminDB, - acmeDB: acmeDB, - acmeResponder: acmeResponder, + auth: auth, + adminDB: adminDB, + acmeDB: acmeDB, + acmeResponder: acmeResponder, + policyResponder: policyResponder, } } // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { - authnz := func(next nextHTTP) nextHTTP { + + authnz := func(next http.HandlerFunc) http.HandlerFunc { return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) } - requireEABEnabled := func(next nextHTTP) nextHTTP { - return h.requireEABEnabled(next) + enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { + return h.checkAction(next, true) + } + + disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { + return h.checkAction(next, false) + } + + acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return authnz(h.loadProvisionerByName(h.requireEABEnabled(next))) + } + + authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return authnz(enabledInStandalone(next)) + } + + provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return authnz(disabledInStandalone(h.loadProvisionerByName(next))) + } + + acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.loadExternalAccountKey(next))))) } // Provisioners @@ -49,8 +74,31 @@ func (h *Handler) Route(r api.Router) { r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) // ACME External Account Binding Keys - r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) - r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) - r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey))) - r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey))) + r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys)) + r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys)) + r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.CreateExternalAccountKey)) + r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(h.acmeResponder.DeleteExternalAccountKey)) + + // Policy - Authority + r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(h.policyResponder.GetAuthorityPolicy)) + r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(h.policyResponder.CreateAuthorityPolicy)) + r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(h.policyResponder.UpdateAuthorityPolicy)) + r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(h.policyResponder.DeleteAuthorityPolicy)) + + // Policy - Provisioner + r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.GetProvisionerPolicy)) + r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.CreateProvisionerPolicy)) + r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.UpdateProvisionerPolicy)) + r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.DeleteProvisionerPolicy)) + + // Policy - ACME Account + r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy)) + r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy)) + r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy)) + r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy)) + r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy)) + r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy)) + r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy)) + r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy)) + } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index b57dd6eb..24adfdf2 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -1,18 +1,23 @@ package api import ( - "context" + "errors" "net/http" + "github.com/go-chi/chi" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/admin/db/nosql" + "github.com/smallstep/certificates/authority/provisioner" ) -type nextHTTP = func(http.ResponseWriter, *http.Request) - // requireAPIEnabled is a middleware that ensures the Administration API // is enabled before servicing requests. -func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { +func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if !h.auth.IsAdminAPIEnabled() { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, @@ -24,8 +29,9 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { } // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. -func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { +func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + tok := r.Header.Get("Authorization") if tok == "" { render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, @@ -39,16 +45,102 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return } - ctx := context.WithValue(r.Context(), adminContextKey, adm) + ctx := linkedca.NewContextWithAdmin(r.Context(), adm) next(w, r.WithContext(ctx)) } } -// ContextKey is the key type for storing and searching for ACME request -// essentials in the context of a request. -type ContextKey string +// loadProvisionerByName is a middleware that searches for a provisioner +// by name and stores it in the context. +func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { -const ( - // adminContextKey account key - adminContextKey = ContextKey("admin") -) + ctx := r.Context() + name := chi.URLParam(r, "provisionerName") + var ( + p provisioner.Interface + err error + ) + + // TODO(hs): distinguish 404 vs. 500 + if p, err = h.auth.LoadProvisionerByName(name); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + return + } + + prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) + if err != nil { + render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name)) + return + } + + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + next(w, r.WithContext(ctx)) + } +} + +// checkAction checks if an action is supported in standalone or not +func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + // actions allowed in standalone mode are always supported + if supportedInStandalone { + next(w, r) + return + } + + // when an action is not supported in standalone mode and when + // using a nosql.DB backend, actions are not supported + if _, ok := h.adminDB.(*nosql.DB); ok { + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, + "operation not supported in standalone mode")) + return + } + + // continue to next http handler + next(w, r) + } +} + +// loadExternalAccountKey is a middleware that searches for an ACME +// External Account Key by reference or keyID and stores it in the context. +func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + prov := linkedca.MustProvisionerFromContext(ctx) + + reference := chi.URLParam(r, "reference") + keyID := chi.URLParam(r, "keyID") + + var ( + eak *acme.ExternalAccountKey + err error + ) + + if keyID != "" { + eak, err = h.acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID) + } else { + eak, err = h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference) + } + + if err != nil { + if errors.Is(err, acme.ErrNotFound) { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) + return + } + render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key")) + return + } + + if eak == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) + return + } + + linkedEAK := eakToLinked(eak) + + ctx = linkedca.NewContextWithExternalAccountKey(ctx, linkedEAK) + + next(w, r.WithContext(ctx)) + } +} diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go index 7fb4671a..42caed9a 100644 --- a/authority/admin/api/middleware_test.go +++ b/authority/admin/api/middleware_test.go @@ -4,25 +4,32 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "testing" "time" + "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/admin" - "go.step.sm/linkedca" "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/admin/db/nosql" + "github.com/smallstep/certificates/authority/provisioner" ) func TestHandler_requireAPIEnabled(t *testing.T) { type test struct { ctx context.Context auth adminAuthority - next nextHTTP + next http.HandlerFunc err *admin.Error statusCode int } @@ -102,7 +109,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { ctx context.Context auth adminAuthority req *http.Request - next nextHTTP + next http.HandlerFunc err *admin.Error statusCode int } @@ -152,7 +159,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { req.Header["Authorization"] = []string{"token"} createdAt := time.Now() var deletedAt time.Time - admin := &linkedca.Admin{ + adm := &linkedca.Admin{ Id: "adminID", AuthorityId: "authorityID", Subject: "admin", @@ -164,20 +171,15 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { auth := &mockAdminAuthority{ MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { assert.Equals(t, "token", token) - return admin, nil + return adm, nil }, } next := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin - adm, ok := a.(*linkedca.Admin) - if !ok { - t.Errorf("expected *linkedca.Admin; got %T", a) - return - } + adm := linkedca.MustAdminFromContext(ctx) // verifying that the context now has a linkedca.Admin opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} - if !cmp.Equal(admin, adm, opts...) { - t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...)) + if !cmp.Equal(adm, adm, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...)) } w.Write(nil) // mock response with status 200 } @@ -223,3 +225,461 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { }) } } + +func TestHandler_loadProvisionerByName(t *testing.T) { + type test struct { + adminDB admin.DB + auth adminAuthority + ctx context.Context + next http.HandlerFunc + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName") + err.Message = "error loading provisioner provName: force" + return test{ + ctx: ctx, + auth: auth, + statusCode: 500, + err: err, + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName") + err.Message = "error retrieving provisioner provName: force" + return test{ + ctx: ctx, + auth: auth, + adminDB: db, + statusCode: 500, + err: err, + } + }, + "ok": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + }, nil + }, + } + return test{ + ctx: ctx, + auth: auth, + adminDB: db, + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + prov := linkedca.MustProvisionerFromContext(r.Context()) + assert.NotNil(t, prov) + assert.Equals(t, "provID", prov.GetId()) + assert.Equals(t, "provName", prov.GetName()) + w.Write(nil) // mock response with status 200 + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + adminDB: tc.adminDB, + } + + req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + + w := httptest.NewRecorder() + h.loadProvisionerByName(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + }) + } +} + +func TestHandler_checkAction(t *testing.T) { + type test struct { + adminDB admin.DB + next http.HandlerFunc + supportedInStandalone bool + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "standalone-nosql-supported": func(t *testing.T) test { + return test{ + supportedInStandalone: true, + adminDB: &nosql.DB{}, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(nil) // mock response with status 200 + }, + statusCode: 200, + } + }, + "standalone-nosql-not-supported": func(t *testing.T) test { + err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode") + err.Message = "operation not supported in standalone mode" + return test{ + supportedInStandalone: false, + adminDB: &nosql.DB{}, + statusCode: 501, + err: err, + } + }, + "standalone-no-nosql-not-supported": func(t *testing.T) test { + err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported") + err.Message = "operation not supported" + return test{ + supportedInStandalone: false, + adminDB: &admin.MockDB{}, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(nil) // mock response with status 200 + }, + statusCode: 200, + err: err, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + + adminDB: tc.adminDB, + } + + req := httptest.NewRequest("GET", "/foo", nil) + w := httptest.NewRecorder() + h.checkAction(tc.next, tc.supportedInStandalone)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + }) + } +} + +func TestHandler_loadExternalAccountKey(t *testing.T) { + type test struct { + ctx context.Context + acmeDB acme.DB + next http.HandlerFunc + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/keyID-not-found-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("keyID", "key") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "key", keyID) + return nil, acme.ErrNotFound + }, + }, + err: err, + statusCode: 404, + } + }, + "fail/keyID-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("keyID", "key") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key") + err.Message = "error retrieving ACME External Account Key: force" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "key", keyID) + return nil, errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/reference-not-found-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("reference", "ref") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "ref", reference) + return nil, acme.ErrNotFound + }, + }, + err: err, + statusCode: 404, + } + }, + "fail/reference-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("reference", "ref") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key") + err.Message = "error retrieving ACME External Account Key: force" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "ref", reference) + return nil, errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/no-key": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("reference", "ref") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "ref", reference) + return nil, nil + }, + }, + err: err, + statusCode: 404, + } + }, + "ok/keyID": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("keyID", "eakID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + createdAt := time.Now().Add(-1 * time.Hour) + var boundAt time.Time + eak := &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + CreatedAt: createdAt, + BoundAt: boundAt, + } + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "eakID", keyID) + return eak, nil + }, + }, + next: func(w http.ResponseWriter, r *http.Request) { + contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context()) + assert.NotNil(t, eak) + exp := &linkedca.EABKey{ + Id: "eakID", + Provisioner: "provID", + CreatedAt: timestamppb.New(createdAt), + BoundAt: timestamppb.New(boundAt), + } + assert.Equals(t, exp, contextEAK) + w.Write(nil) // mock response with status 200 + }, + err: nil, + statusCode: 200, + } + }, + "ok/reference": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("reference", "ref") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") + err.Message = "ACME External Account Key not found" + createdAt := time.Now().Add(-1 * time.Hour) + var boundAt time.Time + eak := &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + Reference: "ref", + CreatedAt: createdAt, + BoundAt: boundAt, + } + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "ref", reference) + return eak, nil + }, + }, + next: func(w http.ResponseWriter, r *http.Request) { + contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context()) + assert.NotNil(t, eak) + exp := &linkedca.EABKey{ + Id: "eakID", + Provisioner: "provID", + Reference: "ref", + CreatedAt: timestamppb.New(createdAt), + BoundAt: timestamppb.New(boundAt), + } + assert.Equals(t, exp, contextEAK) + w.Write(nil) // mock response with status 200 + }, + err: nil, + statusCode: 200, + } + }, + } + + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + acmeDB: tc.acmeDB, + } + + req := httptest.NewRequest("GET", "/foo", nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.loadExternalAccountKey(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + }) + } +} diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go new file mode 100644 index 00000000..6af1104a --- /dev/null +++ b/authority/admin/api/policy.go @@ -0,0 +1,517 @@ +package api + +import ( + "errors" + "net/http" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/policy" +) + +type policyAdminResponderInterface interface { + GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) + CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) + UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) + DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) + GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) + CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) + UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) + DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) + GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) + DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) +} + +// PolicyAdminResponder is responsible for writing ACME admin responses +type PolicyAdminResponder struct { + auth adminAuthority + adminDB admin.DB + acmeDB acme.DB + isLinkedCA bool +} + +// NewACMEAdminResponder returns a new ACMEAdminResponder +func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) *PolicyAdminResponder { + + var isLinkedCA bool + if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok { + isLinkedCA = a.IsLinkedCA() + } + + return &PolicyAdminResponder{ + auth: auth, + adminDB: adminDB, + acmeDB: acmeDB, + isLinkedCA: isLinkedCA, + } +} + +// GetAuthorityPolicy handles the GET /admin/authority/policy request +func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + authorityPolicy, err := par.auth.GetAuthorityPolicy(r.Context()) + if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { + render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) + return + } + + if authorityPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + return + } + + render.ProtoJSONStatus(w, authorityPolicy, http.StatusOK) +} + +// CreateAuthorityPolicy handles the POST /admin/authority/policy request +func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + + if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { + render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) + return + } + + if authorityPolicy != nil { + adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy") + render.Error(w, adminErr) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) + return + } + + adm := linkedca.MustAdminFromContext(ctx) + + var createdPolicy *linkedca.Policy + if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + if isBadRequest(err) { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) + return + } + + render.Error(w, admin.WrapErrorISE(err, "error storing authority policy")) + return + } + + render.ProtoJSONStatus(w, createdPolicy, http.StatusCreated) +} + +// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request +func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + + if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { + render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) + return + } + + if authorityPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) + return + } + + adm := linkedca.MustAdminFromContext(ctx) + + var updatedPolicy *linkedca.Policy + if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + if isBadRequest(err) { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) + return + } + + render.Error(w, admin.WrapErrorISE(err, "error updating authority policy")) + return + } + + render.ProtoJSONStatus(w, updatedPolicy, http.StatusOK) +} + +// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request +func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + + if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { + render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) + return + } + + if authorityPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + return + } + + if err := par.auth.RemoveAuthorityPolicy(ctx); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy")) + return + } + + render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) +} + +// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request +func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + prov := linkedca.MustProvisionerFromContext(r.Context()) + + provisionerPolicy := prov.GetPolicy() + if provisionerPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + return + } + + render.ProtoJSONStatus(w, provisionerPolicy, http.StatusOK) +} + +// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request +func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + prov := linkedca.MustProvisionerFromContext(ctx) + + provisionerPolicy := prov.GetPolicy() + if provisionerPolicy != nil { + adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name) + render.Error(w, adminErr) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) + return + } + + prov.Policy = newPolicy + + if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + if isBadRequest(err) { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy")) + return + } + + render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy")) + return + } + + render.ProtoJSONStatus(w, newPolicy, http.StatusCreated) +} + +// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request +func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + prov := linkedca.MustProvisionerFromContext(ctx) + + provisionerPolicy := prov.GetPolicy() + if provisionerPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) + return + } + + prov.Policy = newPolicy + if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + if isBadRequest(err) { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy")) + return + } + + render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy")) + return + } + + render.ProtoJSONStatus(w, newPolicy, http.StatusOK) +} + +// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request +func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + prov := linkedca.MustProvisionerFromContext(ctx) + + if prov.Policy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + return + } + + // remove the policy + prov.Policy = nil + + if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy")) + return + } + + render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) +} + +func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + eak := linkedca.MustExternalAccountKeyFromContext(ctx) + + eakPolicy := eak.GetPolicy() + if eakPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + return + } + + render.ProtoJSONStatus(w, eakPolicy, http.StatusOK) +} + +func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + prov := linkedca.MustProvisionerFromContext(ctx) + eak := linkedca.MustExternalAccountKeyFromContext(ctx) + + eakPolicy := eak.GetPolicy() + if eakPolicy != nil { + adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id) + render.Error(w, adminErr) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) + return + } + + eak.Policy = newPolicy + + acmeEAK := linkedEAKToCertificates(eak) + if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy")) + return + } + + render.ProtoJSONStatus(w, newPolicy, http.StatusCreated) +} + +func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + prov := linkedca.MustProvisionerFromContext(ctx) + eak := linkedca.MustExternalAccountKeyFromContext(ctx) + + eakPolicy := eak.GetPolicy() + if eakPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + return + } + + var newPolicy = new(linkedca.Policy) + if err := read.ProtoJSON(r.Body, newPolicy); err != nil { + render.Error(w, err) + return + } + + newPolicy.Deduplicate() + + if err := validatePolicy(newPolicy); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) + return + } + + eak.Policy = newPolicy + acmeEAK := linkedEAKToCertificates(eak) + if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy")) + return + } + + render.ProtoJSONStatus(w, newPolicy, http.StatusOK) +} + +func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + + if err := par.blockLinkedCA(); err != nil { + render.Error(w, err) + return + } + + ctx := r.Context() + prov := linkedca.MustProvisionerFromContext(ctx) + eak := linkedca.MustExternalAccountKeyFromContext(ctx) + + eakPolicy := eak.GetPolicy() + if eakPolicy == nil { + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + return + } + + // remove the policy + eak.Policy = nil + + acmeEAK := linkedEAKToCertificates(eak) + if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy")) + return + } + + render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) +} + +// blockLinkedCA blocks all API operations on linked deployments +func (par *PolicyAdminResponder) blockLinkedCA() error { + // temporary blocking linked deployments + if par.isLinkedCA { + return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + } + return nil +} + +// isBadRequest checks if an error should result in a bad request error +// returned to the client. +func isBadRequest(err error) bool { + var pe *authority.PolicyError + isPolicyError := errors.As(err, &pe) + return isPolicyError && (pe.Typ == authority.AdminLockOut || pe.Typ == authority.EvaluationFailure || pe.Typ == authority.ConfigurationFailure) +} + +func validatePolicy(p *linkedca.Policy) error { + + // convert the policy; return early if nil + options := policy.LinkedToCertificates(p) + if options == nil { + return nil + } + + var err error + + // Initialize a temporary x509 allow/deny policy engine + if _, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil { + return err + } + + // Initialize a temporary SSH allow/deny policy engine for host certificates + if _, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil { + return err + } + + // Initialize a temporary SSH allow/deny policy engine for user certificates + if _, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil { + return err + } + + return nil +} diff --git a/authority/admin/api/policy_test.go b/authority/admin/api/policy_test.go new file mode 100644 index 00000000..1e70db52 --- /dev/null +++ b/authority/admin/api/policy_test.go @@ -0,0 +1,2711 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/encoding/protojson" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" +) + +type fakeLinkedCA struct { + admin.MockDB +} + +func (f *fakeLinkedCA) IsLinkedCA() bool { + return true +} + +// testAdminError is an error type that models the expected +// error body returned. +type testAdminError struct { + Type string `json:"type"` + Message string `json:"message"` + Detail string `json:"detail"` +} + +type testX509Policy struct { + Allow *testX509Names `json:"allow,omitempty"` + Deny *testX509Names `json:"deny,omitempty"` + AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"` +} + +type testX509Names struct { + CommonNames []string `json:"commonNames,omitempty"` + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ips,omitempty"` + EmailAddresses []string `json:"emails,omitempty"` + URIDomains []string `json:"uris,omitempty"` +} + +type testSSHPolicy struct { + User *testSSHUserPolicy `json:"user,omitempty"` + Host *testSSHHostPolicy `json:"host,omitempty"` +} + +type testSSHHostPolicy struct { + Allow *testSSHHostNames `json:"allow,omitempty"` + Deny *testSSHHostNames `json:"deny,omitempty"` +} + +type testSSHHostNames struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ips,omitempty"` + Principals []string `json:"principals,omitempty"` +} + +type testSSHUserPolicy struct { + Allow *testSSHUserNames `json:"allow,omitempty"` + Deny *testSSHUserNames `json:"deny,omitempty"` +} + +type testSSHUserNames struct { + EmailAddresses []string `json:"emails,omitempty"` + Principals []string `json:"principals,omitempty"` +} + +// testPolicyResponse models the Policy API JSON response +type testPolicyResponse struct { + X509 *testX509Policy `json:"x509,omitempty"` + SSH *testSSHPolicy `json:"ssh,omitempty"` +} + +func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test { + ctx := context.Background() + err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") + err.Message = "error retrieving authority policy: force" + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorServerInternalType, "force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/auth.GetAuthorityPolicy-not-found": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") + err.Message = "authority policy does not exist" + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + err: err, + statusCode: 404, + } + }, + "ok": func(t *testing.T) test { + ctx := context.Background() + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + } + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + }, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + + req := httptest.NewRequest("GET", "/foo", nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.GetAuthorityPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + }) + } +} + +func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + acmeDB acme.DB + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test { + ctx := context.Background() + err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") + err.Message = "error retrieving authority policy: force" + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorServerInternalType, "force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/existing-policy": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorConflictType, "authority already has a policy") + err.Message = "authority already has a policy" + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{}, nil + }, + }, + err: err, + statusCode: 409, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + ctx := context.Background() + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + ctx := context.Background() + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/CreateAuthorityPolicy-policy-admin-lockout-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error storing authority policy") + adminErr.Message = "error storing authority policy: admin lock out" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + } + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/CreateAuthorityPolicy-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error storing authority policy: force") + adminErr.Message = "error storing authority policy: force" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, &authority.PolicyError{ + Typ: authority.StoreFailure, + Err: errors.New("force"), + } + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return policy, nil + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 201, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.CreateAuthorityPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + acmeDB acme.DB + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test { + ctx := context.Background() + err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") + err.Message = "error retrieving authority policy: force" + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorServerInternalType, "force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") + err.Message = "authority policy does not exist" + err.Status = http.StatusNotFound + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, nil + }, + }, + err: err, + statusCode: 404, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + ctx := context.Background() + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + ctx := context.Background() + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/UpdateAuthorityPolicy-policy-admin-lockout-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error updating authority policy: force") + adminErr.Message = "error updating authority policy: admin lock out" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + } + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/UpdateAuthorityPolicy-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating authority policy: force") + adminErr.Message = "error updating authority policy: force" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return nil, &authority.PolicyError{ + Typ: authority.StoreFailure, + Err: errors.New("force"), + } + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + ctx := context.Background() + ctx = linkedca.NewContextWithAdmin(ctx, adm) + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { + return policy, nil + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{ + adm, + { + Subject: "anotherAdmin", + }, + }, nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.UpdateAuthorityPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + acmeDB acme.DB + err *admin.Error + statusCode int + } + + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test { + ctx := context.Background() + err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") + err.Message = "error retrieving authority policy: force" + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorServerInternalType, "force") + }, + }, + err: err, + statusCode: 500, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") + err.Message = "authority policy does not exist" + err.Status = http.StatusNotFound + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, nil + }, + }, + err: err, + statusCode: 404, + } + }, + "fail/auth.RemoveAuthorityPolicy-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + ctx := context.Background() + err := admin.NewErrorISE("error deleting authority policy: force") + err.Message = "error deleting authority policy: force" + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockRemoveAuthorityPolicy: func(ctx context.Context) error { + return errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + ctx := context.Background() + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return policy, nil + }, + MockRemoveAuthorityPolicy: func(ctx context.Context) error { + return nil + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.DeleteAuthorityPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + res.Body.Close() + response := DeleteResponse{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + }) + } +} + +func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + ctx context.Context + acmeDB acme.DB + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/prov-no-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{} + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist") + err.Message = "provisioner policy does not exist" + return test{ + ctx: ctx, + err: err, + statusCode: 404, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + } + prov := &linkedca.Provisioner{ + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + return test{ + ctx: ctx, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("GET", "/foo", nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.GetProvisionerPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/existing-policy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewError(admin.ErrorConflictType, "provisioner provName already has a policy") + err.Message = "provisioner provName already has a policy" + return test{ + ctx: ctx, + err: err, + statusCode: 409, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/auth.UpdateProvisioner-policy-admin-lockout-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error creating provisioner policy") + adminErr.Message = "error creating provisioner policy: admin lock out" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + } + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error creating provisioner policy: force") + adminErr.Message = "error creating provisioner policy: force" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return &authority.PolicyError{ + Typ: authority.StoreFailure, + Err: errors.New("force"), + } + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 201, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.CreateProvisionerPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { + type test struct { + auth adminAuthority + body []byte + adminDB admin.DB + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist") + err.Message = "provisioner policy does not exist" + return test{ + ctx: ctx, + err: err, + statusCode: 404, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, admin.NewError(admin.ErrorNotFoundType, "not found") + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/auth.UpdateProvisioner-policy-admin-lockout-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error updating provisioner policy") + adminErr.Message = "error updating provisioner policy: admin lock out" + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + } + }, + }, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating provisioner policy: force") + adminErr.Message = "error updating provisioner policy: force" + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return &authority.PolicyError{ + Typ: authority.StoreFailure, + Err: errors.New("force"), + } + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + Subject: "step", + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: policy, + } + ctx := linkedca.NewContextWithAdmin(context.Background(), adm) + ctx = linkedca.NewContextWithProvisioner(ctx, prov) + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.UpdateProvisionerPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { + type test struct { + auth adminAuthority + adminDB admin.DB + body []byte + ctx context.Context + acmeDB acme.DB + err *admin.Error + statusCode int + } + + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist") + err.Message = "provisioner policy does not exist" + return test{ + ctx: ctx, + err: err, + statusCode: 404, + } + }, + "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: &linkedca.Policy{}, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + err := admin.NewErrorISE("error deleting provisioner policy: force") + err.Message = "error deleting provisioner policy: force" + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + Policy: &linkedca.Policy{}, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + return test{ + ctx: ctx, + auth: &mockAdminAuthority{ + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + return nil + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.DeleteProvisionerPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + res.Body.Close() + response := DeleteResponse{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + }) + } +} + +func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { + type test struct { + ctx context.Context + acmeDB acme.DB + adminDB admin.DB + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") + err.Message = "ACME EAK policy does not exist" + return test{ + ctx: ctx, + err: err, + statusCode: 404, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"10.0.0.0/16"}, + Emails: []string{"@example.com"}, + Uris: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad.local"}, + Ips: []string{"10.0.0.30"}, + Emails: []string{"bad@example.com"}, + Uris: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.example.com"}, + Ips: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"bad@example.com"}, + Ips: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + return test{ + ctx: ctx, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"10.0.0.0/16"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"example.com"}, + CommonNames: []string{"test"}, + }, + Deny: &testX509Names{ + DNSDomains: []string{"bad.local"}, + IPRanges: []string{"10.0.0.30"}, + EmailAddresses: []string{"bad@example.com"}, + URIDomains: []string{"notexample.com"}, + CommonNames: []string{"bad"}, + }, + }, + SSH: &testSSHPolicy{ + User: &testSSHUserPolicy{ + Allow: &testSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"*"}, + }, + Deny: &testSSHUserNames{ + EmailAddresses: []string{"bad@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &testSSHHostPolicy{ + Allow: &testSSHHostNames{ + DNSDomains: []string{"*.example.com"}, + IPRanges: []string{"10.10.0.0/16"}, + Principals: []string{"good"}, + }, + Deny: &testSSHHostNames{ + DNSDomains: []string{"bad@example.com"}, + IPRanges: []string{"10.10.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("GET", "/foo", nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.GetACMEAccountPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { + type test struct { + acmeDB acme.DB + adminDB admin.DB + body []byte + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/existing-policy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewError(admin.ErrorConflictType, "ACME EAK eakID already has a policy") + err.Message = "ACME EAK eakID already has a policy" + return test{ + ctx: ctx, + err: err, + statusCode: 409, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error creating ACME EAK policy") + adminErr.Message = "error creating ACME EAK policy: force" + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return errors.New("force") + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 201, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.CreateACMEAccountPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { + type test struct { + acmeDB acme.DB + adminDB admin.DB + body []byte + ctx context.Context + err *admin.Error + response *testPolicyResponse + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") + err.Message = "ACME EAK policy does not exist" + return test{ + ctx: ctx, + err: err, + statusCode: 404, + } + }, + "fail/read.ProtoJSON": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") + adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" + body := []byte("{?}") + return test{ + ctx: ctx, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/validatePolicy": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") + adminErr.Message = "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" + body := []byte(` + { + "x509": { + "allow": { + "uris": [ + "https://example.com" + ] + } + } + }`) + return test{ + ctx: ctx, + body: body, + err: adminErr, + statusCode: 400, + } + }, + "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Id: "provID", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating ACME EAK policy: force") + adminErr.Message = "error updating ACME EAK policy: force" + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return errors.New("force") + }, + }, + body: body, + err: adminErr, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Id: "provID", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + body, err := protojson.Marshal(policy) + assert.NoError(t, err) + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return nil + }, + }, + body: body, + response: &testPolicyResponse{ + X509: &testX509Policy{ + Allow: &testX509Names{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.UpdateACMEAccountPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + // when the error message starts with "proto", we expect it to have + // a syntax error (in the tests). If the message doesn't start with "proto", + // we expect a full string match. + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(ae.Message, "syntax error")) + } else { + assert.Equal(t, tc.err.Message, ae.Message) + } + + return + } + + p := &testPolicyResponse{} + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.NoError(t, json.Unmarshal(body, &p)) + + assert.Equal(t, tc.response, p) + + }) + } +} + +func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { + type test struct { + body []byte + adminDB admin.DB + ctx context.Context + acmeDB acme.DB + err *admin.Error + statusCode int + } + + var tests = map[string]func(t *testing.T) test{ + "fail/linkedca": func(t *testing.T) test { + ctx := context.Background() + err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") + err.Message = "policy operations not yet supported in linked deployments" + return test{ + ctx: ctx, + adminDB: &fakeLinkedCA{}, + err: err, + statusCode: 501, + } + }, + "fail/no-existing-policy": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Name: "provName", + } + eak := &linkedca.EABKey{ + Id: "eakID", + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") + err.Message = "ACME EAK policy does not exist" + return test{ + ctx: ctx, + err: err, + statusCode: 404, + } + }, + "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Id: "provID", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + err := admin.NewErrorISE("error deleting ACME EAK policy: force") + err.Message = "error deleting ACME EAK policy: force" + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return errors.New("force") + }, + }, + err: err, + statusCode: 500, + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + prov := &linkedca.Provisioner{ + Name: "provName", + Id: "provID", + } + eak := &linkedca.EABKey{ + Id: "eakID", + Policy: policy, + } + ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) + ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) + return test{ + ctx: ctx, + acmeDB: &acme.MockDB{ + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + assert.Equal(t, "provID", provisionerID) + assert.Equal(t, "eakID", eak.ID) + return nil + }, + }, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + + par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + + par.DeleteACMEAccountPolicy(w, req) + res := w.Result() + + assert.Equal(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.NoError(t, err) + + ae := testAdminError{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equal(t, tc.err.Type, ae.Type) + assert.Equal(t, tc.err.Message, ae.Message) + assert.Equal(t, tc.err.StatusCode(), res.StatusCode) + assert.Equal(t, tc.err.Detail, ae.Detail) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + assert.NoError(t, err) + res.Body.Close() + response := DeleteResponse{} + assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equal(t, "ok", response.Status) + assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) + + }) + } +} + +func Test_isBadRequest(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "nil", + err: nil, + want: false, + }, + { + name: "no-policy-error", + err: errors.New("some error"), + want: false, + }, + { + name: "no-bad-request", + err: &authority.PolicyError{ + Typ: authority.InternalFailure, + Err: errors.New("error"), + }, + want: false, + }, + { + name: "bad-request", + err: &authority.PolicyError{ + Typ: authority.AdminLockOut, + Err: errors.New("admin lock out"), + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isBadRequest(tt.err); got != tt.want { + t.Errorf("isBadRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_validatePolicy(t *testing.T) { + type args struct { + p *linkedca.Policy + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "nil", + args: args{ + p: nil, + }, + wantErr: false, + }, + { + name: "x509", + args: args{ + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"**.local"}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ssh user", + args: args{ + p: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@@example.com"}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ssh host", + args: args{ + p: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"**.local"}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ok", + args: args{ + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + }, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := validatePolicy(tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("validatePolicy() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go index 6d5024f2..486b6cda 100644 --- a/authority/admin/api/provisioner_test.go +++ b/authority/admin/api/provisioner_test.go @@ -8,18 +8,21 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" "github.com/go-chi/chi" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.step.sm/linkedca" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/types/known/timestamppb" ) func TestHandler_GetProvisioner(t *testing.T) { @@ -335,12 +338,12 @@ func TestHandler_CreateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, - statusCode: 500, - err: &admin.Error{ // TODO(hs): this probably needs a better error - Type: "", - Status: 500, - Detail: "", - Message: "", + statusCode: 400, + err: &admin.Error{ + Type: "badRequest", + Status: 400, + Detail: "bad request", + Message: "proto: syntax error (line 1:2): invalid value !", }, } }, @@ -423,9 +426,15 @@ func TestHandler_CreateProvisioner(t *testing.T) { assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) - assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(adminErr.Message, "syntax error")) + } else { + assert.Equals(t, tc.err.Message, adminErr.Message) + } + return } @@ -616,12 +625,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, - statusCode: 500, - err: &admin.Error{ // TODO(hs): this probably needs a better error - Type: "", - Status: 500, - Detail: "", - Message: "", + statusCode: 400, + err: &admin.Error{ + Type: "badRequest", + Status: 400, + Detail: "bad request", + Message: "proto: syntax error (line 1:2): invalid value !", }, } }, @@ -1074,9 +1083,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) { assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) - assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + if strings.HasPrefix(tc.err.Message, "proto:") { + assert.True(t, strings.Contains(adminErr.Message, "syntax error")) + } else { + assert.Equals(t, tc.err.Message, adminErr.Message) + } + return } diff --git a/authority/admin/db.go b/authority/admin/db.go index bf34a3c2..0c0e7767 100644 --- a/authority/admin/db.go +++ b/authority/admin/db.go @@ -69,6 +69,11 @@ type DB interface { GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error DeleteAdmin(ctx context.Context, id string) error + + CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) + UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error + DeleteAuthorityPolicy(ctx context.Context) error } // MockDB is an implementation of the DB interface that should only be used as @@ -86,6 +91,11 @@ type MockDB struct { MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error MockDeleteAdmin func(ctx context.Context, id string) error + MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) + MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error + MockDeleteAuthorityPolicy func(ctx context.Context) error + MockError error MockRet1 interface{} } @@ -179,3 +189,35 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error { } return m.MockError } + +// CreateAuthorityPolicy mock +func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + if m.MockCreateAuthorityPolicy != nil { + return m.MockCreateAuthorityPolicy(ctx, policy) + } + return m.MockError +} + +// GetAuthorityPolicy mock +func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + if m.MockGetAuthorityPolicy != nil { + return m.MockGetAuthorityPolicy(ctx) + } + return m.MockRet1.(*linkedca.Policy), m.MockError +} + +// UpdateAuthorityPolicy mock +func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + if m.MockUpdateAuthorityPolicy != nil { + return m.MockUpdateAuthorityPolicy(ctx, policy) + } + return m.MockError +} + +// DeleteAuthorityPolicy mock +func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error { + if m.MockDeleteAuthorityPolicy != nil { + return m.MockDeleteAuthorityPolicy(ctx) + } + return m.MockError +} diff --git a/authority/admin/db/nosql/nosql.go b/authority/admin/db/nosql/nosql.go index 22b049f5..32e05d92 100644 --- a/authority/admin/db/nosql/nosql.go +++ b/authority/admin/db/nosql/nosql.go @@ -11,8 +11,9 @@ import ( ) var ( - adminsTable = []byte("admins") - provisionersTable = []byte("provisioners") + adminsTable = []byte("admins") + provisionersTable = []byte("provisioners") + authorityPoliciesTable = []byte("authority_policies") ) // DB is a struct that implements the AdminDB interface. @@ -23,7 +24,7 @@ type DB struct { // New configures and returns a new Authority DB backend implemented using a nosql DB. func New(db nosqlDB.DB, authorityID string) (*DB, error) { - tables := [][]byte{adminsTable, provisionersTable} + tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable} for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", diff --git a/authority/admin/db/nosql/policy.go b/authority/admin/db/nosql/policy.go new file mode 100644 index 00000000..d4f2e9f9 --- /dev/null +++ b/authority/admin/db/nosql/policy.go @@ -0,0 +1,339 @@ +package nosql + +import ( + "context" + "encoding/json" + "fmt" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/nosql" +) + +type dbX509Policy struct { + Allow *dbX509Names `json:"allow,omitempty"` + Deny *dbX509Names `json:"deny,omitempty"` + AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"` +} + +type dbX509Names struct { + CommonNames []string `json:"cn,omitempty"` + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + URIDomains []string `json:"uri,omitempty"` +} + +type dbSSHPolicy struct { + // User contains SSH user certificate options. + User *dbSSHUserPolicy `json:"user,omitempty"` + // Host contains SSH host certificate options. + Host *dbSSHHostPolicy `json:"host,omitempty"` +} + +type dbSSHHostPolicy struct { + Allow *dbSSHHostNames `json:"allow,omitempty"` + Deny *dbSSHHostNames `json:"deny,omitempty"` +} + +type dbSSHHostNames struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + Principals []string `json:"principal,omitempty"` +} + +type dbSSHUserPolicy struct { + Allow *dbSSHUserNames `json:"allow,omitempty"` + Deny *dbSSHUserNames `json:"deny,omitempty"` +} + +type dbSSHUserNames struct { + EmailAddresses []string `json:"email,omitempty"` + Principals []string `json:"principal,omitempty"` +} + +type dbPolicy struct { + X509 *dbX509Policy `json:"x509,omitempty"` + SSH *dbSSHPolicy `json:"ssh,omitempty"` +} + +type dbAuthorityPolicy struct { + ID string `json:"id"` + AuthorityID string `json:"authorityID"` + Policy *dbPolicy `json:"policy,omitempty"` +} + +func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy { + if dbap == nil { + return nil + } + return dbToLinked(dbap.Policy) +} + +func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) { + data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID)) + if nosql.IsErrNotFound(err) { + return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found") + } else if err != nil { + return nil, fmt.Errorf("error loading authority policy: %w", err) + } + return data, nil +} + +func (db *DB) unmarshalDBAuthorityPolicy(data []byte) (*dbAuthorityPolicy, error) { + if len(data) == 0 { + return nil, nil + } + var dba = new(dbAuthorityPolicy) + if err := json.Unmarshal(data, dba); err != nil { + return nil, fmt.Errorf("error unmarshaling policy bytes into dbAuthorityPolicy: %w", err) + } + return dba, nil +} + +func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) { + data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID) + if err != nil { + return nil, err + } + dbap, err := db.unmarshalDBAuthorityPolicy(data) + if err != nil { + return nil, err + } + if dbap == nil { + return nil, nil + } + if dbap.AuthorityID != authorityID { + return nil, admin.NewError(admin.ErrorAuthorityMismatchType, + "authority policy is not owned by authority %s", authorityID) + } + return dbap, nil +} + +func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + + dbap := &dbAuthorityPolicy{ + ID: db.authorityID, + AuthorityID: db.authorityID, + Policy: linkedToDB(policy), + } + + if err := db.save(ctx, dbap.ID, dbap, nil, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error creating authority policy") + } + + return nil +} + +func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return nil, err + } + + return dbap.convert(), nil +} + +func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return err + } + + dbap := &dbAuthorityPolicy{ + ID: db.authorityID, + AuthorityID: db.authorityID, + Policy: linkedToDB(policy), + } + + if err := db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error updating authority policy") + } + + return nil +} + +func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error { + old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + if err != nil { + return err + } + + if err := db.save(ctx, old.ID, nil, old, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error deleting authority policy") + } + + return nil +} + +func dbToLinked(p *dbPolicy) *linkedca.Policy { + if p == nil { + return nil + } + r := &linkedca.Policy{} + if x509 := p.X509; x509 != nil { + r.X509 = &linkedca.X509Policy{} + if allow := x509.Allow; allow != nil { + r.X509.Allow = &linkedca.X509Names{} + r.X509.Allow.Dns = allow.DNSDomains + r.X509.Allow.Emails = allow.EmailAddresses + r.X509.Allow.Ips = allow.IPRanges + r.X509.Allow.Uris = allow.URIDomains + r.X509.Allow.CommonNames = allow.CommonNames + } + if deny := x509.Deny; deny != nil { + r.X509.Deny = &linkedca.X509Names{} + r.X509.Deny.Dns = deny.DNSDomains + r.X509.Deny.Emails = deny.EmailAddresses + r.X509.Deny.Ips = deny.IPRanges + r.X509.Deny.Uris = deny.URIDomains + r.X509.Deny.CommonNames = deny.CommonNames + } + r.X509.AllowWildcardNames = x509.AllowWildcardNames + } + if ssh := p.SSH; ssh != nil { + r.Ssh = &linkedca.SSHPolicy{} + if host := ssh.Host; host != nil { + r.Ssh.Host = &linkedca.SSHHostPolicy{} + if allow := host.Allow; allow != nil { + r.Ssh.Host.Allow = &linkedca.SSHHostNames{} + r.Ssh.Host.Allow.Dns = allow.DNSDomains + r.Ssh.Host.Allow.Ips = allow.IPRanges + r.Ssh.Host.Allow.Principals = allow.Principals + } + if deny := host.Deny; deny != nil { + r.Ssh.Host.Deny = &linkedca.SSHHostNames{} + r.Ssh.Host.Deny.Dns = deny.DNSDomains + r.Ssh.Host.Deny.Ips = deny.IPRanges + r.Ssh.Host.Deny.Principals = deny.Principals + } + } + if user := ssh.User; user != nil { + r.Ssh.User = &linkedca.SSHUserPolicy{} + if allow := user.Allow; allow != nil { + r.Ssh.User.Allow = &linkedca.SSHUserNames{} + r.Ssh.User.Allow.Emails = allow.EmailAddresses + r.Ssh.User.Allow.Principals = allow.Principals + } + if deny := user.Deny; deny != nil { + r.Ssh.User.Deny = &linkedca.SSHUserNames{} + r.Ssh.User.Deny.Emails = deny.EmailAddresses + r.Ssh.User.Deny.Principals = deny.Principals + } + } + } + + return r +} + +func linkedToDB(p *linkedca.Policy) *dbPolicy { + + if p == nil { + return nil + } + + // return early if x509 nor SSH is set + if p.GetX509() == nil && p.GetSsh() == nil { + return nil + } + + r := &dbPolicy{} + // fill x509 policy configuration + if x509 := p.GetX509(); x509 != nil { + r.X509 = &dbX509Policy{} + if allow := x509.GetAllow(); allow != nil { + r.X509.Allow = &dbX509Names{} + if allow.Dns != nil { + r.X509.Allow.DNSDomains = allow.Dns + } + if allow.Ips != nil { + r.X509.Allow.IPRanges = allow.Ips + } + if allow.Emails != nil { + r.X509.Allow.EmailAddresses = allow.Emails + } + if allow.Uris != nil { + r.X509.Allow.URIDomains = allow.Uris + } + if allow.CommonNames != nil { + r.X509.Allow.CommonNames = allow.CommonNames + } + } + if deny := x509.GetDeny(); deny != nil { + r.X509.Deny = &dbX509Names{} + if deny.Dns != nil { + r.X509.Deny.DNSDomains = deny.Dns + } + if deny.Ips != nil { + r.X509.Deny.IPRanges = deny.Ips + } + if deny.Emails != nil { + r.X509.Deny.EmailAddresses = deny.Emails + } + if deny.Uris != nil { + r.X509.Deny.URIDomains = deny.Uris + } + if deny.CommonNames != nil { + r.X509.Deny.CommonNames = deny.CommonNames + } + } + + r.X509.AllowWildcardNames = x509.GetAllowWildcardNames() + } + + // fill ssh policy configuration + if ssh := p.GetSsh(); ssh != nil { + r.SSH = &dbSSHPolicy{} + if host := ssh.GetHost(); host != nil { + r.SSH.Host = &dbSSHHostPolicy{} + if allow := host.GetAllow(); allow != nil { + r.SSH.Host.Allow = &dbSSHHostNames{} + if allow.Dns != nil { + r.SSH.Host.Allow.DNSDomains = allow.Dns + } + if allow.Ips != nil { + r.SSH.Host.Allow.IPRanges = allow.Ips + } + if allow.Principals != nil { + r.SSH.Host.Allow.Principals = allow.Principals + } + } + if deny := host.GetDeny(); deny != nil { + r.SSH.Host.Deny = &dbSSHHostNames{} + if deny.Dns != nil { + r.SSH.Host.Deny.DNSDomains = deny.Dns + } + if deny.Ips != nil { + r.SSH.Host.Deny.IPRanges = deny.Ips + } + if deny.Principals != nil { + r.SSH.Host.Deny.Principals = deny.Principals + } + } + } + if user := ssh.GetUser(); user != nil { + r.SSH.User = &dbSSHUserPolicy{} + if allow := user.GetAllow(); allow != nil { + r.SSH.User.Allow = &dbSSHUserNames{} + if allow.Emails != nil { + r.SSH.User.Allow.EmailAddresses = allow.Emails + } + if allow.Principals != nil { + r.SSH.User.Allow.Principals = allow.Principals + } + } + if deny := user.GetDeny(); deny != nil { + r.SSH.User.Deny = &dbSSHUserNames{} + if deny.Emails != nil { + r.SSH.User.Deny.EmailAddresses = deny.Emails + } + if deny.Principals != nil { + r.SSH.User.Deny.Principals = deny.Principals + } + } + } + } + + return r +} diff --git a/authority/admin/db/nosql/policy_test.go b/authority/admin/db/nosql/policy_test.go new file mode 100644 index 00000000..3ffded6b --- /dev/null +++ b/authority/admin/db/nosql/policy_test.go @@ -0,0 +1,1206 @@ +package nosql + +import ( + "context" + "encoding/json" + "errors" + "reflect" + "testing" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" + "go.step.sm/linkedca" +) + +func TestDB_getDBAuthorityPolicyBytes(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte("foo"), nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if b, err := d.getDBAuthorityPolicyBytes(tc.ctx, tc.authorityID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, string(b), "foo") + } + }) + } +} + +func TestDB_getDBAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + dbap *dbAuthorityPolicy + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling policy bytes into dbAuthorityPolicy"), + } + }, + "fail/authorityID-error": func(t *testing.T) test { + dbp := &dbAuthorityPolicy{ + ID: "ID", + AuthorityID: "diffAuthID", + Policy: linkedToDB(&linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }), + } + b, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return b, nil + }, + }, + adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, + "authority policy is not owned by authority authID"), + } + }, + "ok/empty-bytes": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte{}, nil + }, + }, + } + }, + "ok": func(t *testing.T) test { + dbap := &dbAuthorityPolicy{ + ID: "ID", + AuthorityID: authID, + Policy: linkedToDB(&linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }), + } + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return b, nil + }, + }, + dbap: dbap, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + dbp, err := d.getDBAuthorityPolicy(tc.ctx, tc.authorityID) + switch { + case err != nil: + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + case assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) && tc.dbap == nil: + assert.Nil(t, dbp) + case assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr): + assert.Equals(t, dbp.ID, "ID") + assert.Equals(t, dbp.AuthorityID, tc.dbap.AuthorityID) + assert.Equals(t, dbp.Policy, tc.dbap.Policy) + } + }) + } +} + +func TestDB_CreateAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/save-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, linkedToDB(policy)) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error creating authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, old, nil) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, linkedToDB(policy)) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.CreateAuthorityPolicy(tc.ctx, tc.policy); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} + +func TestDB_GetAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(policy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + got, err := d.GetAuthorityPolicy(tc.ctx) + if err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + + assert.NotNil(t, got) + assert.Equals(t, tc.policy, got) + }) + } +} + +func TestDB_UpdateAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(oldPolicy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, linkedToDB(policy)) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error updating authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(oldPolicy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, linkedToDB(policy)) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.UpdateAuthorityPolicy(tc.ctx, tc.policy); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + }) + } +} + +func TestDB_DeleteAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(oldPolicy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + assert.Equals(t, nil, nu) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error deleting authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: linkedToDB(oldPolicy), + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + assert.Equals(t, nil, nu) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.DeleteAuthorityPolicy(tc.ctx); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + }) + } +} + +func Test_linkedToDB(t *testing.T) { + type args struct { + p *linkedca.Policy + } + tests := []struct { + name string + args args + want *dbPolicy + }{ + { + name: "nil policy", + args: args{ + p: nil, + }, + want: nil, + }, + { + name: "no x509 nor ssh", + args: args{ + p: &linkedca.Policy{}, + }, + want: nil, + }, + { + name: "x509", + args: args{ + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Emails: []string{"@example.com"}, + Uris: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Emails: []string{"root@example.com"}, + Uris: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + }, + }, + want: &dbPolicy{ + X509: &dbX509Policy{ + Allow: &dbX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &dbX509Names{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + EmailAddresses: []string{"root@example.com"}, + URIDomains: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + }, + }, + { + name: "ssh user", + args: args{ + p: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + want: &dbPolicy{ + SSH: &dbSSHPolicy{ + User: &dbSSHUserPolicy{ + Allow: &dbSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &dbSSHUserNames{ + EmailAddresses: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + { + name: "full ssh policy", + args: args{ + p: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + want: &dbPolicy{ + SSH: &dbSSHPolicy{ + Host: &dbSSHHostPolicy{ + Allow: &dbSSHHostNames{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &dbSSHHostNames{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + { + name: "full policy", + args: args{ + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Emails: []string{"@example.com"}, + Uris: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Emails: []string{"root@example.com"}, + Uris: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + want: &dbPolicy{ + X509: &dbX509Policy{ + Allow: &dbX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &dbX509Names{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + EmailAddresses: []string{"root@example.com"}, + URIDomains: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + SSH: &dbSSHPolicy{ + User: &dbSSHUserPolicy{ + Allow: &dbSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &dbSSHUserNames{ + EmailAddresses: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &dbSSHHostPolicy{ + Allow: &dbSSHHostNames{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &dbSSHHostNames{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := linkedToDB(tt.args.p); !reflect.DeepEqual(got, tt.want) { + t.Errorf("linkedToDB() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_dbToLinked(t *testing.T) { + type args struct { + p *dbPolicy + } + tests := []struct { + name string + args args + want *linkedca.Policy + }{ + { + name: "nil policy", + args: args{ + p: nil, + }, + want: nil, + }, + { + name: "x509", + args: args{ + p: &dbPolicy{ + X509: &dbX509Policy{ + Allow: &dbX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &dbX509Names{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + EmailAddresses: []string{"root@example.com"}, + URIDomains: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Emails: []string{"@example.com"}, + Uris: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Emails: []string{"root@example.com"}, + Uris: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + }, + }, + { + name: "ssh user", + args: args{ + p: &dbPolicy{ + SSH: &dbSSHPolicy{ + User: &dbSSHUserPolicy{ + Allow: &dbSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &dbSSHUserNames{ + EmailAddresses: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + want: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + { + name: "ssh host", + args: args{ + p: &dbPolicy{ + SSH: &dbSSHPolicy{ + Host: &dbSSHHostPolicy{ + Allow: &dbSSHHostNames{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &dbSSHHostNames{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + want: &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + { + name: "full policy", + args: args{ + p: &dbPolicy{ + X509: &dbX509Policy{ + Allow: &dbX509Names{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + EmailAddresses: []string{"@example.com"}, + URIDomains: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &dbX509Names{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + EmailAddresses: []string{"root@example.com"}, + URIDomains: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + SSH: &dbSSHPolicy{ + User: &dbSSHUserPolicy{ + Allow: &dbSSHUserNames{ + EmailAddresses: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &dbSSHUserNames{ + EmailAddresses: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &dbSSHHostPolicy{ + Allow: &dbSSHHostNames{ + DNSDomains: []string{"*.local"}, + IPRanges: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &dbSSHHostNames{ + DNSDomains: []string{"badhost.local"}, + IPRanges: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Emails: []string{"@example.com"}, + Uris: []string{"*.example.com"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Emails: []string{"root@example.com"}, + Uris: []string{"bad.example.com"}, + CommonNames: []string{"bad name"}, + }, + AllowWildcardNames: true, + }, + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@example.com"}, + Principals: []string{"root"}, + }, + }, + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + Ips: []string{"192.168.0.1/24"}, + Principals: []string{"host"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + Ips: []string{"192.168.0.30"}, + Principals: []string{"bad"}, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := dbToLinked(tt.args.p); !reflect.DeepEqual(got, tt.want) { + t.Errorf("dbToLinked() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/admin/errors.go b/authority/admin/errors.go index baa32dd9..2cf0c0e5 100644 --- a/authority/admin/errors.go +++ b/authority/admin/errors.go @@ -24,10 +24,12 @@ const ( ErrorBadRequestType // ErrorNotImplementedType not implemented. ErrorNotImplementedType - // ErrorUnauthorizedType internal server error. + // ErrorUnauthorizedType unauthorized. ErrorUnauthorizedType // ErrorServerInternalType internal server error. ErrorServerInternalType + // ErrorConflictType conflict. + ErrorConflictType ) // String returns the string representation of the admin problem type, @@ -48,6 +50,8 @@ func (ap ProblemType) String() string { return "unauthorized" case ErrorServerInternalType: return "internalServerError" + case ErrorConflictType: + return "conflict" default: return fmt.Sprintf("unsupported error type '%d'", int(ap)) } @@ -64,7 +68,7 @@ var ( errorServerInternalMetadata = errorMetadata{ typ: ErrorServerInternalType.String(), details: "the server experienced an internal error", - status: 500, + status: http.StatusInternalServerError, } errorMap = map[ProblemType]errorMetadata{ ErrorNotFoundType: { @@ -98,6 +102,11 @@ var ( status: http.StatusUnauthorized, }, ErrorServerInternalType: errorServerInternalMetadata, + ErrorConflictType: { + typ: ErrorConflictType.String(), + details: "conflict", + status: http.StatusConflict, + }, } ) diff --git a/authority/administrator/collection.go b/authority/administrator/collection.go index 88d7bb2c..f40e7417 100644 --- a/authority/administrator/collection.go +++ b/authority/administrator/collection.go @@ -59,12 +59,12 @@ func newSubProv(subject, prov string) subProv { return subProv{subject, prov} } -// LoadBySubProv a admin by the subject and provisioner name. +// LoadBySubProv loads an admin by subject and provisioner name. func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) { return loadAdmin(c.bySubProv, newSubProv(sub, provName)) } -// LoadByProvisioner a admin by the subject and provisioner name. +// LoadByProvisioner loads admins by provisioner name. func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) { val, ok := c.byProv.Load(provName) if !ok { @@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool } // Store adds an admin to the collection and enforces the uniqueness of -// admin IDs and amdin subject <-> provisioner name combos. +// admin IDs and admin subject <-> provisioner name combos. func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error { // Input validation. if adm.ProvisionerId != prov.GetID() { diff --git a/authority/authority.go b/authority/authority.go index 9db38e14..8a0013c0 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -12,10 +12,16 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/pemutil" + "go.step.sm/linkedca" + "github.com/smallstep/certificates/authority/admin" adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql" "github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" @@ -26,9 +32,6 @@ import ( "github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/templates" "github.com/smallstep/nosql" - "go.step.sm/crypto/pemutil" - "go.step.sm/linkedca" - "golang.org/x/crypto/ssh" ) // Authority implements the Certificate Authority internal interface. @@ -77,6 +80,9 @@ type Authority struct { authorizeRenewFunc provisioner.AuthorizeRenewFunc authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc + // Policy engines + policyEngine *policy.Engine + adminMutex sync.RWMutex } @@ -209,6 +215,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error { a.provisioners = provClxn a.config.AuthorityConfig.Admins = adminList a.admins = adminClxn + return nil } @@ -555,6 +562,11 @@ func (a *Authority) init() error { return err } + // Load x509 and SSH Policy Engines + if err := a.reloadPolicyEngines(context.Background()); err != nil { + return err + } + // Configure templates, currently only ssh templates are supported. if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil { a.templates = a.config.Templates diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 0a1ef53c..087318be 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -491,7 +491,7 @@ func TestAuthority_authorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 8, got) + assert.Equals(t, 9, len(got)) // number of provisioner.SignOptions returned } } }) @@ -1034,7 +1034,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 7, got) + assert.Len(t, 8, got) // number of provisioner.SignOptions returned } } }) diff --git a/authority/config/config.go b/authority/config/config.go index 682321db..c764e8f9 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -8,12 +8,15 @@ import ( "time" "github.com/pkg/errors" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" cas "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" kms "github.com/smallstep/certificates/kms/apiv1" "github.com/smallstep/certificates/templates" - "go.step.sm/linkedca" ) const ( @@ -95,6 +98,7 @@ type AuthConfig struct { Admins []*linkedca.Admin `json:"-"` Template *ASN1DN `json:"template,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"` + Policy *policy.Options `json:"policy,omitempty"` DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` Backdate *provisioner.Duration `json:"backdate,omitempty"` EnableAdmin bool `json:"enableAdmin,omitempty"` diff --git a/authority/linkedca.go b/authority/linkedca.go index 29201164..0552f2d1 100644 --- a/authority/linkedca.go +++ b/authority/linkedca.go @@ -15,16 +15,19 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/tlsutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" - "golang.org/x/crypto/ssh" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" ) const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$" @@ -35,6 +38,9 @@ type linkedCaClient struct { authorityID string } +// interface guard +var _ admin.DB = (*linkedCaClient)(nil) + type linkedCAClaims struct { jose.Claims SANs []string `json:"sans"` @@ -116,6 +122,13 @@ func newLinkedCAClient(token string) (*linkedCaClient, error) { }, nil } +// IsLinkedCA is a sentinel function that can be used to +// check if a linkedCaClient is the underlying type of an +// admin.DB interface. +func (c *linkedCaClient) IsLinkedCA() bool { + return true +} + func (c *linkedCaClient) Run() { c.renewer.Run() } @@ -340,6 +353,22 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) { return resp.Status != linkedca.RevocationStatus_ACTIVE, nil } +func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("not implemented yet") +} + +func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("not implemented yet") +} + +func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("not implemented yet") +} + +func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error { + return errors.New("not implemented yet") +} + func createProvisionerIdentity(prov provisioner.Interface) *linkedca.ProvisionerIdentity { if prov == nil { return nil diff --git a/authority/policy.go b/authority/policy.go new file mode 100644 index 00000000..6348c690 --- /dev/null +++ b/authority/policy.go @@ -0,0 +1,258 @@ +package authority + +import ( + "context" + "errors" + "fmt" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/admin" + authPolicy "github.com/smallstep/certificates/authority/policy" + policy "github.com/smallstep/certificates/policy" +) + +type policyErrorType int + +const ( + AdminLockOut policyErrorType = iota + 1 + StoreFailure + ReloadFailure + ConfigurationFailure + EvaluationFailure + InternalFailure +) + +type PolicyError struct { + Typ policyErrorType + Err error +} + +func (p *PolicyError) Error() string { + return p.Err.Error() +} + +func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + p, err := a.adminDB.GetAuthorityPolicy(ctx) + if err != nil { + return nil, &PolicyError{ + Typ: InternalFailure, + Err: err, + } + } + + return p, nil +} + +func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil { + return nil, err + } + + if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil { + return nil, &PolicyError{ + Typ: StoreFailure, + Err: err, + } + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return nil, &PolicyError{ + Typ: ReloadFailure, + Err: fmt.Errorf("error reloading policy engines when creating authority policy: %w", err), + } + } + + return p, nil +} + +func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil { + return nil, err + } + + if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil { + return nil, &PolicyError{ + Typ: StoreFailure, + Err: err, + } + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return nil, &PolicyError{ + Typ: ReloadFailure, + Err: fmt.Errorf("error reloading policy engines when updating authority policy: %w", err), + } + } + + return p, nil +} + +func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil { + return &PolicyError{ + Typ: StoreFailure, + Err: err, + } + } + + if err := a.reloadPolicyEngines(ctx); err != nil { + return &PolicyError{ + Typ: ReloadFailure, + Err: fmt.Errorf("error reloading policy engines when deleting authority policy: %w", err), + } + } + + return nil +} + +func (a *Authority) checkAuthorityPolicy(ctx context.Context, currentAdmin *linkedca.Admin, p *linkedca.Policy) error { + + // no policy and thus nothing to evaluate; return early + if p == nil { + return nil + } + + // get all current admins from the database + allAdmins, err := a.adminDB.GetAdmins(ctx) + if err != nil { + return &PolicyError{ + Typ: InternalFailure, + Err: fmt.Errorf("error retrieving admins: %w", err), + } + } + + return a.checkPolicy(ctx, currentAdmin, allAdmins, p) +} + +func (a *Authority) checkProvisionerPolicy(ctx context.Context, currentAdmin *linkedca.Admin, provName string, p *linkedca.Policy) error { + + // no policy and thus nothing to evaluate; return early + if p == nil { + return nil + } + + // get all admins for the provisioner; ignoring case in which they're not found + allProvisionerAdmins, _ := a.admins.LoadByProvisioner(provName) + + return a.checkPolicy(ctx, currentAdmin, allProvisionerAdmins, p) +} + +// checkPolicy checks if a new or updated policy configuration results in the user +// locking themselves or other admins out of the CA. +func (a *Authority) checkPolicy(ctx context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error { + + // convert the policy; return early if nil + policyOptions := authPolicy.LinkedToCertificates(p) + if policyOptions == nil { + return nil + } + + engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options()) + if err != nil { + return &PolicyError{ + Typ: ConfigurationFailure, + Err: err, + } + } + + // when an empty X.509 policy is provided, the resulting engine is nil + // and there's no policy to evaluate. + if engine == nil { + return nil + } + + // TODO(hs): Provide option to force the policy, even when the admin subject would be locked out? + + // check if the admin user that instructed the authority policy to be + // created or updated, would still be allowed when the provided policy + // would be applied. + sans := []string{currentAdmin.GetSubject()} + if err := isAllowed(engine, sans); err != nil { + return err + } + + // loop through admins to verify that none of them would be + // locked out when the new policy were to be applied. Returns + // an error with a message that includes the admin subject that + // would be locked out. + for _, adm := range otherAdmins { + sans = []string{adm.GetSubject()} + if err := isAllowed(engine, sans); err != nil { + return err + } + } + + // TODO(hs): mask the error message for non-super admins? + + return nil +} + +// reloadPolicyEngines reloads x509 and SSH policy engines using +// configuration stored in the DB or from the configuration file. +func (a *Authority) reloadPolicyEngines(ctx context.Context) error { + var ( + err error + policyOptions *authPolicy.Options + ) + + if a.config.AuthorityConfig.EnableAdmin { + + // temporarily disable policy loading when LinkedCA is in use + if _, ok := a.adminDB.(*linkedCaClient); ok { + return nil + } + + linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx) + if err != nil { + var ae *admin.Error + if isAdminError := errors.As(err, &ae); (isAdminError && ae.Type != admin.ErrorNotFoundType.String()) || !isAdminError { + return fmt.Errorf("error getting policy to (re)load policy engines: %w", err) + } + } + policyOptions = authPolicy.LinkedToCertificates(linkedPolicy) + } else { + policyOptions = a.config.AuthorityConfig.Policy + } + + engine, err := authPolicy.New(policyOptions) + if err != nil { + return err + } + + // only update the policy engine when no error was returned + a.policyEngine = engine + + return nil +} + +func isAllowed(engine authPolicy.X509Policy, sans []string) error { + if err := engine.AreSANsAllowed(sans); err != nil { + var policyErr *policy.NamePolicyError + isNamePolicyError := errors.As(err, &policyErr) + if isNamePolicyError && policyErr.Reason == policy.NotAllowed { + return &PolicyError{ + Typ: AdminLockOut, + Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans), + } + } + return &PolicyError{ + Typ: EvaluationFailure, + Err: err, + } + } + + return nil +} diff --git a/authority/policy/engine.go b/authority/policy/engine.go new file mode 100644 index 00000000..4b21f66b --- /dev/null +++ b/authority/policy/engine.go @@ -0,0 +1,114 @@ +package policy + +import ( + "crypto/x509" + "errors" + "fmt" + + "golang.org/x/crypto/ssh" +) + +// Engine is a container for multiple policies. +type Engine struct { + x509Policy X509Policy + sshUserPolicy UserPolicy + sshHostPolicy HostPolicy +} + +// New returns a new Engine using Options. +func New(options *Options) (*Engine, error) { + + // if no options provided, return early + if options == nil { + return nil, nil + } + + var ( + x509Policy X509Policy + sshHostPolicy HostPolicy + sshUserPolicy UserPolicy + err error + ) + + // initialize the x509 allow/deny policy engine + if x509Policy, err = NewX509PolicyEngine(options.GetX509Options()); err != nil { + return nil, err + } + + // initialize the SSH allow/deny policy engine for host certificates + if sshHostPolicy, err = NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil { + return nil, err + } + + // initialize the SSH allow/deny policy engine for user certificates + if sshUserPolicy, err = NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil { + return nil, err + } + + return &Engine{ + x509Policy: x509Policy, + sshHostPolicy: sshHostPolicy, + sshUserPolicy: sshUserPolicy, + }, nil +} + +// IsX509CertificateAllowed evaluates an X.509 certificate against +// the X.509 policy (if available) and returns an error if one of the +// names in the certificate is not allowed. +func (e *Engine) IsX509CertificateAllowed(cert *x509.Certificate) error { + + // return early if there's no policy to evaluate + if e == nil || e.x509Policy == nil { + return nil + } + + // return result of X.509 policy evaluation + return e.x509Policy.IsX509CertificateAllowed(cert) +} + +// AreSANsAllowed evaluates the slice of SANs against the X.509 policy +// (if available) and returns an error if one of the SANs is not allowed. +func (e *Engine) AreSANsAllowed(sans []string) error { + + // return early if there's no policy to evaluate + if e == nil || e.x509Policy == nil { + return nil + } + + // return result of X.509 policy evaluation + return e.x509Policy.AreSANsAllowed(sans) +} + +// IsSSHCertificateAllowed evaluates an SSH certificate against the +// user or host policy (if configured) and returns an error if one of the +// principals in the certificate is not allowed. +func (e *Engine) IsSSHCertificateAllowed(cert *ssh.Certificate) error { + + // return early if there's no policy to evaluate + if e == nil || (e.sshHostPolicy == nil && e.sshUserPolicy == nil) { + return nil + } + + switch cert.CertType { + case ssh.HostCert: + // when no host policy engine is configured, but a user policy engine is + // configured, the host certificate is denied. + if e.sshHostPolicy == nil && e.sshUserPolicy != nil { + return errors.New("authority not allowed to sign ssh host certificates") + } + + // return result of SSH host policy evaluation + return e.sshHostPolicy.IsSSHCertificateAllowed(cert) + case ssh.UserCert: + // when no user policy engine is configured, but a host policy engine is + // configured, the user certificate is denied. + if e.sshUserPolicy == nil && e.sshHostPolicy != nil { + return errors.New("authority not allowed to sign ssh user certificates") + } + + // return result of SSH user policy evaluation + return e.sshUserPolicy.IsSSHCertificateAllowed(cert) + default: + return fmt.Errorf("unexpected ssh certificate type %q", cert.CertType) + } +} diff --git a/authority/policy/options.go b/authority/policy/options.go new file mode 100644 index 00000000..b93d2cd1 --- /dev/null +++ b/authority/policy/options.go @@ -0,0 +1,194 @@ +package policy + +// Options is a container for authority level x509 and SSH +// policy configuration. +type Options struct { + X509 *X509PolicyOptions `json:"x509,omitempty"` + SSH *SSHPolicyOptions `json:"ssh,omitempty"` +} + +// GetX509Options returns the x509 authority level policy +// configuration +func (o *Options) GetX509Options() *X509PolicyOptions { + if o == nil { + return nil + } + return o.X509 +} + +// GetSSHOptions returns the SSH authority level policy +// configuration +func (o *Options) GetSSHOptions() *SSHPolicyOptions { + if o == nil { + return nil + } + return o.SSH +} + +// X509PolicyOptionsInterface is an interface for providers +// of x509 allowed and denied names. +type X509PolicyOptionsInterface interface { + GetAllowedNameOptions() *X509NameOptions + GetDeniedNameOptions() *X509NameOptions + AreWildcardNamesAllowed() bool +} + +// X509PolicyOptions is a container for x509 allowed and denied +// names. +type X509PolicyOptions struct { + // AllowedNames contains the x509 allowed names + AllowedNames *X509NameOptions `json:"allow,omitempty"` + + // DeniedNames contains the x509 denied names + DeniedNames *X509NameOptions `json:"deny,omitempty"` + + // AllowWildcardNames indicates if literal wildcard names + // like *.example.com are allowed. Defaults to false. + AllowWildcardNames bool `json:"allowWildcardNames,omitempty"` +} + +// X509NameOptions models the X509 name policy configuration. +type X509NameOptions struct { + CommonNames []string `json:"cn,omitempty"` + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + URIDomains []string `json:"uri,omitempty"` +} + +// HasNames checks if the AllowedNameOptions has one or more +// names configured. +func (o *X509NameOptions) HasNames() bool { + return len(o.CommonNames) > 0 || + len(o.DNSDomains) > 0 || + len(o.IPRanges) > 0 || + len(o.EmailAddresses) > 0 || + len(o.URIDomains) > 0 +} + +// GetAllowedNameOptions returns x509 allowed name policy configuration +func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the x509 denied name policy configuration +func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +// AreWildcardNamesAllowed returns whether the authority allows +// literal wildcard names to be signed. +func (o *X509PolicyOptions) AreWildcardNamesAllowed() bool { + if o == nil { + return true + } + return o.AllowWildcardNames +} + +// SSHPolicyOptionsInterface is an interface for providers of +// SSH user and host name policy configuration. +type SSHPolicyOptionsInterface interface { + GetAllowedUserNameOptions() *SSHNameOptions + GetDeniedUserNameOptions() *SSHNameOptions + GetAllowedHostNameOptions() *SSHNameOptions + GetDeniedHostNameOptions() *SSHNameOptions +} + +// SSHPolicyOptions is a container for SSH user and host policy +// configuration +type SSHPolicyOptions struct { + // User contains SSH user certificate options. + User *SSHUserCertificateOptions `json:"user,omitempty"` + // Host contains SSH host certificate options. + Host *SSHHostCertificateOptions `json:"host,omitempty"` +} + +// GetAllowedUserNameOptions returns the SSH allowed user name policy +// configuration. +func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions { + if o == nil || o.User == nil { + return nil + } + return o.User.AllowedNames +} + +// GetDeniedUserNameOptions returns the SSH denied user name policy +// configuration. +func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions { + if o == nil || o.User == nil { + return nil + } + return o.User.DeniedNames +} + +// GetAllowedHostNameOptions returns the SSH allowed host name policy +// configuration. +func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions { + if o == nil || o.Host == nil { + return nil + } + return o.Host.AllowedNames +} + +// GetDeniedHostNameOptions returns the SSH denied host name policy +// configuration. +func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions { + if o == nil || o.Host == nil { + return nil + } + return o.Host.DeniedNames +} + +// SSHUserCertificateOptions is a collection of SSH user certificate options. +type SSHUserCertificateOptions struct { + // AllowedNames contains the names the provisioner is authorized to sign + AllowedNames *SSHNameOptions `json:"allow,omitempty"` + // DeniedNames contains the names the provisioner is not authorized to sign + DeniedNames *SSHNameOptions `json:"deny,omitempty"` +} + +// SSHHostCertificateOptions is a collection of SSH host certificate options. +// It's an alias of SSHUserCertificateOptions, as the options are the same +// for both types of certificates. +type SSHHostCertificateOptions SSHUserCertificateOptions + +// SSHNameOptions models the SSH name policy configuration. +type SSHNameOptions struct { + DNSDomains []string `json:"dns,omitempty"` + IPRanges []string `json:"ip,omitempty"` + EmailAddresses []string `json:"email,omitempty"` + Principals []string `json:"principal,omitempty"` +} + +// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the +// names that a provisioner is authorized to sign SSH certificates for. +func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the +// names that a provisioner is NOT authorized to sign SSH certificates for. +func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +// HasNames checks if the SSHNameOptions has one or more +// names configured. +func (o *SSHNameOptions) HasNames() bool { + return len(o.DNSDomains) > 0 || + len(o.IPRanges) > 0 || + len(o.EmailAddresses) > 0 || + len(o.Principals) > 0 +} diff --git a/authority/policy/options_test.go b/authority/policy/options_test.go new file mode 100644 index 00000000..0fd6e7c6 --- /dev/null +++ b/authority/policy/options_test.go @@ -0,0 +1,45 @@ +package policy + +import ( + "testing" +) + +func TestX509PolicyOptions_IsWildcardLiteralAllowed(t *testing.T) { + tests := []struct { + name string + options *X509PolicyOptions + want bool + }{ + { + name: "nil-options", + options: nil, + want: true, + }, + { + name: "not-set", + options: &X509PolicyOptions{}, + want: false, + }, + { + name: "set-true", + options: &X509PolicyOptions{ + AllowWildcardNames: true, + }, + want: true, + }, + { + name: "set-false", + options: &X509PolicyOptions{ + AllowWildcardNames: false, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.options.AreWildcardNamesAllowed(); got != tt.want { + t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/policy/policy.go b/authority/policy/policy.go new file mode 100644 index 00000000..3c53b704 --- /dev/null +++ b/authority/policy/policy.go @@ -0,0 +1,256 @@ +package policy + +import ( + "fmt" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/policy" +) + +// X509Policy is an alias for policy.X509NamePolicyEngine +type X509Policy policy.X509NamePolicyEngine + +// UserPolicy is an alias for policy.SSHNamePolicyEngine +type UserPolicy policy.SSHNamePolicyEngine + +// HostPolicy is an alias for policy.SSHNamePolicyEngine +type HostPolicy policy.SSHNamePolicyEngine + +// NewX509PolicyEngine creates a new x509 name policy engine +func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) { + + // return early if no policy engine options to configure + if policyOptions == nil { + return nil, nil + } + + options := []policy.NamePolicyOption{} + + allowed := policyOptions.GetAllowedNameOptions() + if allowed != nil && allowed.HasNames() { + options = append(options, + policy.WithPermittedCommonNames(allowed.CommonNames...), + policy.WithPermittedDNSDomains(allowed.DNSDomains...), + policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...), + policy.WithPermittedEmailAddresses(allowed.EmailAddresses...), + policy.WithPermittedURIDomains(allowed.URIDomains...), + ) + } + + denied := policyOptions.GetDeniedNameOptions() + if denied != nil && denied.HasNames() { + options = append(options, + policy.WithExcludedCommonNames(denied.CommonNames...), + policy.WithExcludedDNSDomains(denied.DNSDomains...), + policy.WithExcludedIPsOrCIDRs(denied.IPRanges...), + policy.WithExcludedEmailAddresses(denied.EmailAddresses...), + policy.WithExcludedURIDomains(denied.URIDomains...), + ) + } + + // ensure no policy engine is returned when no name options were provided + if len(options) == 0 { + return nil, nil + } + + // check if configuration specifies that wildcard names are allowed + if policyOptions.AreWildcardNamesAllowed() { + options = append(options, policy.WithAllowLiteralWildcardNames()) + } + + // enable subject common name verification by default + options = append(options, policy.WithSubjectCommonNameVerification()) + + return policy.New(options...) +} + +type sshPolicyEngineType string + +const ( + UserPolicyEngineType sshPolicyEngineType = "user" + HostPolicyEngineType sshPolicyEngineType = "host" +) + +// newSSHUserPolicyEngine creates a new SSH user certificate policy engine +func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) { + policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType) + if err != nil { + return nil, err + } + return policyEngine, nil +} + +// newSSHHostPolicyEngine create a new SSH host certificate policy engine +func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) { + policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType) + if err != nil { + return nil, err + } + return policyEngine, nil +} + +// newSSHPolicyEngine creates a new SSH name policy engine +func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) { + + // return early if no policy engine options to configure + if policyOptions == nil { + return nil, nil + } + + var ( + allowed *SSHNameOptions + denied *SSHNameOptions + ) + + switch typ { + case UserPolicyEngineType: + allowed = policyOptions.GetAllowedUserNameOptions() + denied = policyOptions.GetDeniedUserNameOptions() + case HostPolicyEngineType: + allowed = policyOptions.GetAllowedHostNameOptions() + denied = policyOptions.GetDeniedHostNameOptions() + default: + return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ) + } + + options := []policy.NamePolicyOption{} + + if allowed != nil && allowed.HasNames() { + options = append(options, + policy.WithPermittedDNSDomains(allowed.DNSDomains...), + policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...), + policy.WithPermittedEmailAddresses(allowed.EmailAddresses...), + policy.WithPermittedPrincipals(allowed.Principals...), + ) + } + + if denied != nil && denied.HasNames() { + options = append(options, + policy.WithExcludedDNSDomains(denied.DNSDomains...), + policy.WithExcludedIPsOrCIDRs(denied.IPRanges...), + policy.WithExcludedEmailAddresses(denied.EmailAddresses...), + policy.WithExcludedPrincipals(denied.Principals...), + ) + } + + // ensure no policy engine is returned when no name options were provided + if len(options) == 0 { + return nil, nil + } + + return policy.New(options...) +} + +func LinkedToCertificates(p *linkedca.Policy) *Options { + + // return early + if p == nil { + return nil + } + + // return early if x509 nor SSH is set + if p.GetX509() == nil && p.GetSsh() == nil { + return nil + } + + opts := &Options{} + + // fill x509 policy configuration + if x509 := p.GetX509(); x509 != nil { + opts.X509 = &X509PolicyOptions{} + if allow := x509.GetAllow(); allow != nil { + opts.X509.AllowedNames = &X509NameOptions{} + if allow.Dns != nil { + opts.X509.AllowedNames.DNSDomains = allow.Dns + } + if allow.Ips != nil { + opts.X509.AllowedNames.IPRanges = allow.Ips + } + if allow.Emails != nil { + opts.X509.AllowedNames.EmailAddresses = allow.Emails + } + if allow.Uris != nil { + opts.X509.AllowedNames.URIDomains = allow.Uris + } + if allow.CommonNames != nil { + opts.X509.AllowedNames.CommonNames = allow.CommonNames + } + } + if deny := x509.GetDeny(); deny != nil { + opts.X509.DeniedNames = &X509NameOptions{} + if deny.Dns != nil { + opts.X509.DeniedNames.DNSDomains = deny.Dns + } + if deny.Ips != nil { + opts.X509.DeniedNames.IPRanges = deny.Ips + } + if deny.Emails != nil { + opts.X509.DeniedNames.EmailAddresses = deny.Emails + } + if deny.Uris != nil { + opts.X509.DeniedNames.URIDomains = deny.Uris + } + if deny.CommonNames != nil { + opts.X509.DeniedNames.CommonNames = deny.CommonNames + } + } + + opts.X509.AllowWildcardNames = x509.GetAllowWildcardNames() + } + + // fill ssh policy configuration + if ssh := p.GetSsh(); ssh != nil { + opts.SSH = &SSHPolicyOptions{} + if host := ssh.GetHost(); host != nil { + opts.SSH.Host = &SSHHostCertificateOptions{} + if allow := host.GetAllow(); allow != nil { + opts.SSH.Host.AllowedNames = &SSHNameOptions{} + if allow.Dns != nil { + opts.SSH.Host.AllowedNames.DNSDomains = allow.Dns + } + if allow.Ips != nil { + opts.SSH.Host.AllowedNames.IPRanges = allow.Ips + } + if allow.Principals != nil { + opts.SSH.Host.AllowedNames.Principals = allow.Principals + } + } + if deny := host.GetDeny(); deny != nil { + opts.SSH.Host.DeniedNames = &SSHNameOptions{} + if deny.Dns != nil { + opts.SSH.Host.DeniedNames.DNSDomains = deny.Dns + } + if deny.Ips != nil { + opts.SSH.Host.DeniedNames.IPRanges = deny.Ips + } + if deny.Principals != nil { + opts.SSH.Host.DeniedNames.Principals = deny.Principals + } + } + } + if user := ssh.GetUser(); user != nil { + opts.SSH.User = &SSHUserCertificateOptions{} + if allow := user.GetAllow(); allow != nil { + opts.SSH.User.AllowedNames = &SSHNameOptions{} + if allow.Emails != nil { + opts.SSH.User.AllowedNames.EmailAddresses = allow.Emails + } + if allow.Principals != nil { + opts.SSH.User.AllowedNames.Principals = allow.Principals + } + } + if deny := user.GetDeny(); deny != nil { + opts.SSH.User.DeniedNames = &SSHNameOptions{} + if deny.Emails != nil { + opts.SSH.User.DeniedNames.EmailAddresses = deny.Emails + } + if deny.Principals != nil { + opts.SSH.User.DeniedNames.Principals = deny.Principals + } + } + } + } + + return opts +} diff --git a/authority/policy/policy_test.go b/authority/policy/policy_test.go new file mode 100644 index 00000000..9210ad90 --- /dev/null +++ b/authority/policy/policy_test.go @@ -0,0 +1,155 @@ +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "go.step.sm/linkedca" +) + +func TestPolicyToCertificates(t *testing.T) { + type args struct { + policy *linkedca.Policy + } + tests := []struct { + name string + args args + want *Options + }{ + { + name: "nil", + args: args{ + policy: nil, + }, + want: nil, + }, + { + name: "no-policy", + args: args{ + &linkedca.Policy{}, + }, + want: nil, + }, + { + name: "partial-policy", + args: args{ + &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + AllowWildcardNames: false, + }, + }, + }, + want: &Options{ + X509: &X509PolicyOptions{ + AllowedNames: &X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + AllowWildcardNames: false, + }, + }, + }, + { + name: "full-policy", + args: args{ + &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step"}, + Ips: []string{"127.0.0.1/24"}, + Emails: []string{"*.example.com"}, + Uris: []string{"https://*.local"}, + CommonNames: []string{"some name"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"bad"}, + Ips: []string{"127.0.0.30"}, + Emails: []string{"badhost.example.com"}, + Uris: []string{"https://badhost.local"}, + CommonNames: []string{"another name"}, + }, + AllowWildcardNames: true, + }, + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.localhost"}, + Ips: []string{"127.0.0.1/24"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.localhost"}, + Ips: []string{"127.0.0.40"}, + Principals: []string{"root"}, + }, + }, + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@work"}, + Principals: []string{"user"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"root@work"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + want: &Options{ + X509: &X509PolicyOptions{ + AllowedNames: &X509NameOptions{ + DNSDomains: []string{"step"}, + IPRanges: []string{"127.0.0.1/24"}, + EmailAddresses: []string{"*.example.com"}, + URIDomains: []string{"https://*.local"}, + CommonNames: []string{"some name"}, + }, + DeniedNames: &X509NameOptions{ + DNSDomains: []string{"bad"}, + IPRanges: []string{"127.0.0.30"}, + EmailAddresses: []string{"badhost.example.com"}, + URIDomains: []string{"https://badhost.local"}, + CommonNames: []string{"another name"}, + }, + AllowWildcardNames: true, + }, + SSH: &SSHPolicyOptions{ + Host: &SSHHostCertificateOptions{ + AllowedNames: &SSHNameOptions{ + DNSDomains: []string{"*.localhost"}, + IPRanges: []string{"127.0.0.1/24"}, + Principals: []string{"user"}, + }, + DeniedNames: &SSHNameOptions{ + DNSDomains: []string{"badhost.localhost"}, + IPRanges: []string{"127.0.0.40"}, + Principals: []string{"root"}, + }, + }, + User: &SSHUserCertificateOptions{ + AllowedNames: &SSHNameOptions{ + EmailAddresses: []string{"@work"}, + Principals: []string{"user"}, + }, + DeniedNames: &SSHNameOptions{ + EmailAddresses: []string{"root@work"}, + Principals: []string{"root"}, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := LinkedToCertificates(tt.args.policy) + if !cmp.Equal(tt.want, got) { + t.Errorf("policyToCertificates() diff=\n%s", cmp.Diff(tt.want, got)) + } + }) + } +} diff --git a/authority/policy_test.go b/authority/policy_test.go new file mode 100644 index 00000000..efeb743b --- /dev/null +++ b/authority/policy_test.go @@ -0,0 +1,1549 @@ +package authority + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/administrator" + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" +) + +func TestAuthority_checkPolicy(t *testing.T) { + type test struct { + ctx context.Context + currentAdmin *linkedca.Admin + otherAdmins []*linkedca.Admin + policy *linkedca.Policy + err *PolicyError + } + tests := map[string]func(t *testing.T) test{ + "fail/NewX509PolicyEngine-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"**.local"}, + }, + }, + }, + err: &PolicyError{ + Typ: ConfigurationFailure, + Err: errors.New("cannot parse permitted domain constraint \"**.local\": domain constraint \"**.local\" can only have wildcard as starting character"), + }, + } + }, + "fail/currentAdmin-evaluation-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "*"}, + otherAdmins: []*linkedca.Admin{}, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }, + err: &PolicyError{ + Typ: EvaluationFailure, + Err: errors.New("cannot parse dns domain \"*\""), + }, + } + }, + "fail/currentAdmin-lockout": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{ + { + Subject: "otherAdmin", + }, + }, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }, + err: &PolicyError{ + Typ: AdminLockOut, + Err: errors.New("the provided policy would lock out [step] from the CA. Please update your policy to include [step] as an allowed name"), + }, + } + }, + "fail/otherAdmins-evaluation-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{ + { + Subject: "other", + }, + { + Subject: "**", + }, + }, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "other", "*.local"}, + }, + }, + }, + err: &PolicyError{ + Typ: EvaluationFailure, + Err: errors.New("cannot parse dns domain \"**\""), + }, + } + }, + "fail/otherAdmins-lockout": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{ + { + Subject: "otherAdmin", + }, + }, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step"}, + }, + }, + }, + err: &PolicyError{ + Typ: AdminLockOut, + Err: errors.New("the provided policy would lock out [otherAdmin] from the CA. Please update your policy to include [otherAdmin] as an allowed name"), + }, + } + }, + "ok/no-policy": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{}, + policy: nil, + } + }, + "ok/empty-policy": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{}, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{}, + }, + }, + }, + } + }, + "ok/policy": func(t *testing.T) test { + return test{ + ctx: context.Background(), + currentAdmin: &linkedca.Admin{Subject: "step"}, + otherAdmins: []*linkedca.Admin{ + { + Subject: "otherAdmin", + }, + }, + policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + } + }, + } + + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + a := &Authority{} + + err := a.checkPolicy(tc.ctx, tc.currentAdmin, tc.otherAdmins, tc.policy) + + if tc.err == nil { + assert.Nil(t, err) + } else { + assert.IsType(t, &PolicyError{}, err) + + pe, ok := err.(*PolicyError) + assert.True(t, ok) + + assert.Equal(t, tc.err.Typ, pe.Typ) + assert.Equal(t, tc.err.Error(), pe.Error()) + } + }) + } +} + +func mustPolicyEngine(t *testing.T, options *policy.Options) *policy.Engine { + engine, err := policy.New(options) + if err != nil { + t.Fatal(err) + } + return engine +} + +func TestAuthority_reloadPolicyEngines(t *testing.T) { + + existingPolicyEngine, err := policy.New(&policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.hosts.example.com"}, + }, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.hosts.example.com"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"@mails.example.com"}, + }, + }, + }, + }) + assert.NoError(t, err) + + newX509Options := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + } + + newSSHHostOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + }, + } + + newSSHUserOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + } + + newSSHOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + } + + newOptions := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + } + + newAdminX509Options := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + } + + newAdminSSHHostOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + }, + } + + newAdminSSHUserOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"@example.com"}, + }, + }, + }, + } + + newAdminOptions := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"@example.com"}, + }, + DeniedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"baduser@example.com"}, + }, + }, + }, + } + + tests := []struct { + name string + config *config.Config + adminDB admin.DB + ctx context.Context + expected *policy.Engine + wantErr bool + }{ + { + name: "fail/standalone-x509-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"**.local"}, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/standalone-ssh-host-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"**.local"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/standalone-ssh-user-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"**example.com"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/adminDB.GetAuthorityPolicy-error", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/admin-x509-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"**.local"}, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/admin-ssh-host-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"**.local"}, + }, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "fail/admin-ssh-user-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@@example.com"}, + }, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: true, + expected: existingPolicyEngine, + }, + { + name: "ok/linkedca-unsupported", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &linkedCaClient{}, + ctx: context.Background(), + wantErr: false, + expected: existingPolicyEngine, + }, + { + name: "ok/standalone-no-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: nil, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, nil), + }, + { + name: "ok/standalone-x509-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newX509Options), + }, + { + name: "ok/standalone-ssh-host-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newSSHHostOptions), + }, + { + name: "ok/standalone-ssh-user-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newSSHUserOptions), + }, + { + name: "ok/standalone-ssh-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newSSHOptions), + }, + { + name: "ok/standalone-full-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: false, + Policy: &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newOptions), + }, + { + name: "ok/admin-x509-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newAdminX509Options), + }, + { + name: "ok/admin-ssh-host-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + }, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newAdminSSHHostOptions), + }, + { + name: "ok/admin-ssh-user-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + Ssh: &linkedca.SSHPolicy{ + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + }, + }, + }, + }, nil + }, + }, + ctx: context.Background(), + wantErr: false, + expected: mustPolicyEngine(t, newAdminSSHUserOptions), + }, + { + name: "ok/admin-full-policy", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + ctx: context.Background(), + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + Ssh: &linkedca.SSHPolicy{ + Host: &linkedca.SSHHostPolicy{ + Allow: &linkedca.SSHHostNames{ + Dns: []string{"*.local"}, + }, + Deny: &linkedca.SSHHostNames{ + Dns: []string{"badhost.local"}, + }, + }, + User: &linkedca.SSHUserPolicy{ + Allow: &linkedca.SSHUserNames{ + Emails: []string{"@example.com"}, + }, + Deny: &linkedca.SSHUserNames{ + Emails: []string{"baduser@example.com"}, + }, + }, + }, + }, nil + }, + }, + wantErr: false, + expected: mustPolicyEngine(t, newAdminOptions), + }, + { + // both DB and JSON config; DB config is taken if Admin API is enabled + name: "ok/admin-over-standalone", + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + Policy: &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + DeniedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"badhost.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"*"}, + }, + DeniedNames: &policy.SSHNameOptions{ + Principals: []string{"root"}, + }, + }, + }, + }, + }, + }, + ctx: context.Background(), + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + Deny: &linkedca.X509Names{ + Dns: []string{"badhost.local"}, + }, + AllowWildcardNames: true, + }, + }, nil + }, + }, + wantErr: false, + expected: mustPolicyEngine(t, newX509Options), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.config, + adminDB: tt.adminDB, + policyEngine: existingPolicyEngine, + } + if err := a.reloadPolicyEngines(tt.ctx); (err != nil) != tt.wantErr { + t.Errorf("Authority.reloadPolicyEngines() error = %v, wantErr %v", err, tt.wantErr) + } + + assert.Equal(t, tt.expected, a.policyEngine) + }) + } +} + +func TestAuthority_checkAuthorityPolicy(t *testing.T) { + type fields struct { + provisioners *provisioner.Collection + admins *administrator.Collection + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + currentAdmin *linkedca.Admin + provName string + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "no policy", + fields: fields{}, + args: args{ + currentAdmin: nil, + provName: "prov", + p: nil, + }, + wantErr: false, + }, + { + name: "fail/adminDB.GetAdmins-error", + fields: fields{ + admins: administrator.NewCollection(nil), + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ok", + fields: fields{ + admins: administrator.NewCollection(nil), + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + provisioners: tt.fields.provisioners, + admins: tt.fields.admins, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + if err := a.checkAuthorityPolicy(tt.args.ctx, tt.args.currentAdmin, tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("Authority.checkProvisionerPolicy() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAuthority_checkProvisionerPolicy(t *testing.T) { + type fields struct { + provisioners *provisioner.Collection + admins *administrator.Collection + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + currentAdmin *linkedca.Admin + provName string + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "no policy", + fields: fields{}, + args: args{ + currentAdmin: nil, + provName: "prov", + p: nil, + }, + wantErr: false, + }, + { + name: "ok", + fields: fields{ + admins: administrator.NewCollection(nil), + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + provisioners: tt.fields.provisioners, + admins: tt.fields.admins, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + if err := a.checkProvisionerPolicy(tt.args.ctx, tt.args.currentAdmin, tt.args.provName, tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("Authority.checkProvisionerPolicy() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAuthority_RemoveAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + wantErr *PolicyError + }{ + { + name: "fail/adminDB.DeleteAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return nil + }, + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when deleting authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return nil + }, + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, nil + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + err := a.RemoveAuthorityPolicy(tt.args.ctx) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + }) + } +} + +func TestAuthority_GetAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/adminDB.GetAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{}, nil + }, + }, + }, + want: &linkedca.Policy{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.GetAuthorityPolicy(tt.args.ctx) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_CreateAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + adm *linkedca.Admin + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/a.checkAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("error retrieving admins: force"), + }, + }, + { + name: "fail/adminDB.CreateAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + MockCreateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when creating authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, nil + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.CreateAuthorityPolicy(tt.args.ctx, tt.args.adm, tt.args.p) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.CreateAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_UpdateAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + adm *linkedca.Admin + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/a.checkAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("error retrieving admins: force"), + }, + }, + { + name: "fail/adminDB.UpdateAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when updating authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return nil + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.UpdateAuthorityPolicy(tt.args.ctx, tt.args.adm, tt.args.p) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.UpdateAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index b5d806ab..9374d985 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -3,6 +3,8 @@ package provisioner import ( "context" "crypto/x509" + "fmt" + "net" "time" "github.com/pkg/errors" @@ -23,7 +25,8 @@ type ACME struct { RequireEAB bool `json:"requireEAB,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - ctl *Controller + + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -71,7 +74,7 @@ func (p *ACME) DefaultTLSCertDuration() time.Duration { return p.ctl.Claimer.DefaultTLSCertDuration() } -// Init initializes and validates the fields of a JWK type. +// Init initializes and validates the fields of an ACME type. func (p *ACME) Init(config Config) (err error) { switch { case p.Type == "": @@ -80,15 +83,56 @@ func (p *ACME) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } +// ACMEIdentifierType encodes ACME Identifier types +type ACMEIdentifierType string + +const ( + // IP is the ACME ip identifier type + IP ACMEIdentifierType = "ip" + // DNS is the ACME dns identifier type + DNS ACMEIdentifierType = "dns" +) + +// ACMEIdentifier encodes ACME Order Identifiers +type ACMEIdentifier struct { + Type ACMEIdentifierType + Value string +} + +// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a +// certificate for an ACME Order Identifier. +func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error { + + x509Policy := p.ctl.getPolicy().getX509() + + // identifier is allowed if no policy is configured + if x509Policy == nil { + return nil + } + + // assuming only valid identifiers (IP or DNS) are provided + var err error + switch identifier.Type { + case IP: + err = x509Policy.IsIPAllowed(net.ParseIP(identifier.Value)) + case DNS: + err = x509Policy.IsDNSAllowed(identifier.Value) + default: + err = fmt.Errorf("invalid ACME identifier type '%s' provided", identifier.Type) + } + + return err +} + // AuthorizeSign does not do any validation, because all validation is handled // in the ACME protocol. This method returns a list of modifiers / constraints // on the resulting certificate. func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - return []SignOption{ + opts := []SignOption{ p, // modifiers / withOptions newProvisionerExtensionOption(TypeACME, p.Name, ""), @@ -97,7 +141,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), - }, nil + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), + } + + return opts, nil } // AuthorizeRevoke is called just before the certificate is to be revoked by diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 1c9a88cc..33cbbc75 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -176,7 +176,7 @@ func TestACME_AuthorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) && assert.NotNil(t, opts) { - assert.Len(t, 6, opts) + assert.Equals(t, 7, len(opts)) // number of SignOptions returned for _, o := range opts { switch v := o.(type) { case *ACME: @@ -193,6 +193,8 @@ func TestACME_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 9d27e016..8433fde5 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -17,10 +17,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // awsIssuer is the string used as issuer in the generated tokens. @@ -421,7 +423,7 @@ func (p *AWS) Init(config Config) (err error) { } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -476,6 +478,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultPublicKeyValidator{}, commonNameValidator(payload.Claims.Subject), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), ), nil } @@ -543,7 +546,7 @@ func (p *AWS) readURL(url string) ([]byte, error) { if err != nil { return nil, err } - return nil, fmt.Errorf("Request for metadata returned non-successful status code %d", + return nil, fmt.Errorf("request for metadata returned non-successful status code %d", resp.StatusCode) } @@ -576,7 +579,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) { } defer resp.Body.Close() if resp.StatusCode >= 400 { - return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode) + return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode) } token, err := io.ReadAll(resp.Body) if err != nil { @@ -754,5 +757,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), ), nil } diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 7027a446..0d9b5d4d 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -642,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1, "foo.local"}, 7, http.StatusOK, false}, - {"ok", p2, args{t2, "instance-id"}, 11, http.StatusOK, false}, - {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 11, http.StatusOK, false}, - {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 11, http.StatusOK, false}, - {"ok", p1, args{t4, "instance-id"}, 7, http.StatusOK, false}, + {"ok", p1, args{t1, "foo.local"}, 8, http.StatusOK, false}, + {"ok", p2, args{t2, "instance-id"}, 12, http.StatusOK, false}, + {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 12, http.StatusOK, false}, + {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 12, http.StatusOK, false}, + {"ok", p1, args{t4, "instance-id"}, 8, http.StatusOK, false}, {"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true}, {"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true}, @@ -673,7 +673,7 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: - assert.Len(t, tt.wantLen, got) + assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { case *AWS: @@ -699,6 +699,8 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index e6323e9f..438ab5b3 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -13,10 +13,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens. @@ -219,7 +221,7 @@ func (p *Azure) Init(config Config) (err error) { return } - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -360,6 +362,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), ), nil } @@ -425,6 +428,8 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), ), nil } diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index a8a0a271..3e745a5b 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -474,11 +474,11 @@ func TestAzure_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 6, http.StatusOK, false}, - {"ok", p2, args{t2}, 11, http.StatusOK, false}, - {"ok", p1, args{t11}, 6, http.StatusOK, false}, - {"ok", p5, args{t5}, 6, http.StatusOK, false}, - {"ok", p7, args{t7}, 6, http.StatusOK, false}, + {"ok", p1, args{t1}, 7, http.StatusOK, false}, + {"ok", p2, args{t2}, 12, http.StatusOK, false}, + {"ok", p1, args{t11}, 7, http.StatusOK, false}, + {"ok", p5, args{t5}, 7, http.StatusOK, false}, + {"ok", p7, args{t7}, 7, http.StatusOK, false}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, {"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true}, @@ -502,7 +502,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: - assert.Len(t, tt.wantLen, got) + assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { case *Azure: @@ -528,6 +528,8 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go index afd28dcc..0ca40267 100644 --- a/authority/provisioner/controller.go +++ b/authority/provisioner/controller.go @@ -21,14 +21,19 @@ type Controller struct { IdentityFunc GetIdentityFunc AuthorizeRenewFunc AuthorizeRenewFunc AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc + policy *policyEngine } // NewController initializes a new provisioner controller. -func NewController(p Interface, claims *Claims, config Config) (*Controller, error) { +func NewController(p Interface, claims *Claims, config Config, options *Options) (*Controller, error) { claimer, err := NewClaimer(claims, config.Claims) if err != nil { return nil, err } + policy, err := newPolicyEngine(options) + if err != nil { + return nil, err + } return &Controller{ Interface: p, Audiences: &config.Audiences, @@ -36,6 +41,7 @@ func NewController(p Interface, claims *Claims, config Config) (*Controller, err IdentityFunc: config.GetIdentityFunc, AuthorizeRenewFunc: config.AuthorizeRenewFunc, AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc, + policy: policy, }, nil } @@ -192,3 +198,10 @@ func SanitizeSSHUserPrincipal(email string) string { } }, strings.ToLower(email)) } + +func (c *Controller) getPolicy() *policyEngine { + if c == nil { + return nil + } + return c.policy +} diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go index ebd38df1..37cbfd89 100644 --- a/authority/provisioner/controller_test.go +++ b/authority/provisioner/controller_test.go @@ -9,6 +9,8 @@ import ( "time" "golang.org/x/crypto/ssh" + + "github.com/smallstep/certificates/authority/policy" ) var trueValue = true @@ -30,11 +32,40 @@ func mustDuration(t *testing.T, s string) *Duration { return d } +func mustNewPolicyEngine(t *testing.T, options *Options) *policyEngine { + t.Helper() + c, err := newPolicyEngine(options) + if err != nil { + t.Fatal(err) + } + return c +} + func TestNewController(t *testing.T) { + options := &Options{ + X509: &X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + SSH: &SSHOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.local"}, + }, + }, + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + EmailAddresses: []string{"@example.com"}, + }, + }, + }, + } type args struct { - p Interface - claims *Claims - config Config + p Interface + claims *Claims + config Config + options *Options } tests := []struct { name string @@ -45,7 +76,7 @@ func TestNewController(t *testing.T) { {"ok", args{&JWK{}, nil, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, - }}, &Controller{ + }, nil}, &Controller{ Interface: &JWK{}, Audiences: &testAudiences, Claimer: mustClaimer(t, nil, globalProvisionerClaims), @@ -55,24 +86,49 @@ func TestNewController(t *testing.T) { }, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, - }}, &Controller{ + }, nil}, &Controller{ Interface: &JWK{}, Audiences: &testAudiences, Claimer: mustClaimer(t, &Claims{ DisableRenewal: &defaultDisableRenewal, }, globalProvisionerClaims), }, false}, + {"ok with claims and options", args{&JWK{}, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }, options}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, globalProvisionerClaims), + policy: mustNewPolicyEngine(t, options), + }, false}, {"fail claimer", args{&JWK{}, &Claims{ MinTLSDur: mustDuration(t, "24h"), MaxTLSDur: mustDuration(t, "2h"), }, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, + }, nil}, nil, true}, + {"fail options", args{&JWK{}, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }, &Options{ + X509: &X509Options{ + AllowedNames: &policy.X509NameOptions{ + DNSDomains: []string{"**.local"}, + }, + }, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewController(tt.args.p, tt.args.claims, tt.args.config) + got, err := NewController(tt.args.p, tt.args.claims, tt.args.config, tt.args.options) if (err != nil) != tt.wantErr { t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index 69d909a2..94c19e17 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -14,10 +14,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // gcpCertsURL is the url that serves Google OAuth2 public keys. @@ -212,7 +214,7 @@ func (p *GCP) Init(config Config) (err error) { } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -270,6 +272,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), ), nil } @@ -432,5 +435,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), ), nil } diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index dfb9a329..3c0bf92e 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -516,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 6, http.StatusOK, false}, - {"ok", p2, args{t2}, 11, http.StatusOK, false}, - {"ok", p3, args{t3}, 6, http.StatusOK, false}, + {"ok", p1, args{t1}, 7, http.StatusOK, false}, + {"ok", p2, args{t2}, 12, http.StatusOK, false}, + {"ok", p3, args{t3}, 7, http.StatusOK, false}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true}, @@ -545,7 +545,7 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: - assert.Len(t, tt.wantLen, got) + assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { case *GCP: @@ -571,6 +571,8 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 3c5032fb..336736db 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -7,10 +7,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // jwtPayload extends jwt.Claims with step attributes. @@ -97,7 +99,7 @@ func (p *JWK) Init(config Config) (err error) { return errors.New("provisioner key cannot be empty") } - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -141,6 +143,7 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err // revoke the certificate with serial number in the `sub` property. func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) + // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") } @@ -180,6 +183,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultPublicKeyValidator{}, defaultSANsValidator(claims.SANs), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), }, nil } @@ -188,6 +192,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { + // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew) return p.ctl.AuthorizeRenew(ctx, cert) } @@ -260,11 +265,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), ), nil } // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) + // TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 926f9d68..bd8b542b 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -297,7 +297,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } } else { if assert.NotNil(t, got) { - assert.Len(t, 8, got) + assert.Equals(t, 9, len(got)) for _, o := range got { switch v := o.(type) { case *JWK: @@ -317,6 +317,8 @@ func TestJWK_AuthorizeSign(t *testing.T) { assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index 083773e0..e2dbf840 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -10,11 +10,13 @@ import ( "net/http" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // NOTE: There can be at most one kubernetes service account provisioner configured @@ -91,6 +93,7 @@ func (p *K8sSA) GetEncryptedKey() (string, string, bool) { // Init initializes and validates the fields of a K8sSA type. func (p *K8sSA) Init(config Config) (err error) { + switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -137,7 +140,7 @@ func (p *K8sSA) Init(config Config) (err error) { p.kauthn = k8s.AuthenticationV1() */ - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -239,6 +242,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), }, nil } @@ -281,6 +285,8 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), ), nil } diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index e98b6f48..1eff379d 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -280,7 +280,6 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - tot := 0 for _, o := range opts { switch v := o.(type) { case *K8sSA: @@ -296,12 +295,13 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } - tot++ } - assert.Equals(t, tot, 6) + assert.Equals(t, 7, len(opts)) } } } @@ -368,7 +368,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - tot := 0 + assert.Len(t, 7, opts) for _, o := range opts { switch v := o.(type) { case sshCertificateOptionsFunc: @@ -380,12 +380,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshCertDefaultValidator: case *sshDefaultDuration: assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) + case *sshNamePolicyValidator: + assert.Equals(t, nil, v.userPolicyEngine) + assert.Equals(t, nil, v.hostPolicyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } - tot++ } - assert.Equals(t, tot, 6) } } } diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index 4216e997..38a2409f 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -10,12 +10,14 @@ import ( "github.com/pkg/errors" nebula "github.com/slackhq/nebula/cert" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x25519" "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" + + "github.com/smallstep/certificates/errs" ) const ( @@ -61,7 +63,7 @@ func (p *Nebula) Init(config Config) (err error) { } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -161,6 +163,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, }, defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), }, nil } @@ -256,6 +259,8 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), ), nil } diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 3a9398a2..9f389b29 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -12,10 +12,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // openIDConfiguration contains the necessary properties in the @@ -195,7 +197,7 @@ func (o *OIDC) Init(config Config) (err error) { return err } - o.ctl, err = NewController(o, o.Claims, config) + o.ctl, err = NewController(o, o.Claims, config, o.Options) return } @@ -353,6 +355,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators defaultPublicKeyValidator{}, newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(o.ctl.getPolicy().getX509()), }, nil } @@ -439,6 +442,8 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption &sshCertValidityValidator{o.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(o.ctl.getPolicy().getSSHHost(), o.ctl.getPolicy().getSSHUser()), ), nil } diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 548c4dc8..3d039496 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -323,7 +323,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { - assert.Len(t, 6, got) + assert.Equals(t, 7, len(got)) for _, o := range got { switch v := o.(type) { case *OIDC: @@ -341,6 +341,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } diff --git a/authority/provisioner/options.go b/authority/provisioner/options.go index f86c4863..f5c919b4 100644 --- a/authority/provisioner/options.go +++ b/authority/provisioner/options.go @@ -5,8 +5,11 @@ import ( "strings" "github.com/pkg/errors" + "go.step.sm/crypto/jose" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/authority/policy" ) // CertificateOptions is an interface that returns a list of options passed when @@ -56,6 +59,16 @@ type X509Options struct { // TemplateData is a JSON object with variables that can be used in custom // templates. TemplateData json.RawMessage `json:"templateData,omitempty"` + + // AllowedNames contains the SANs the provisioner is authorized to sign + AllowedNames *policy.X509NameOptions `json:"-"` + + // DeniedNames contains the SANs the provisioner is not authorized to sign + DeniedNames *policy.X509NameOptions `json:"-"` + + // AllowWildcardNames indicates if literal wildcard names + // like *.example.com are allowed. Defaults to false. + AllowWildcardNames bool `json:"-"` } // HasTemplate returns true if a template is defined in the provisioner options. @@ -63,6 +76,31 @@ func (o *X509Options) HasTemplate() bool { return o != nil && (o.Template != "" || o.TemplateFile != "") } +// GetAllowedNameOptions returns the AllowedNames, which models the +// SANs that a provisioner is authorized to sign x509 certificates for. +func (o *X509Options) GetAllowedNameOptions() *policy.X509NameOptions { + if o == nil { + return nil + } + return o.AllowedNames +} + +// GetDeniedNameOptions returns the DeniedNames, which models the +// SANs that a provisioner is NOT authorized to sign x509 certificates for. +func (o *X509Options) GetDeniedNameOptions() *policy.X509NameOptions { + if o == nil { + return nil + } + return o.DeniedNames +} + +func (o *X509Options) AreWildcardNamesAllowed() bool { + if o == nil { + return true + } + return o.AllowWildcardNames +} + // TemplateOptions generates a CertificateOptions with the template and data // defined in the ProvisionerOptions, the provisioner generated data, and the // user data provided in the request. If no template has been provided, diff --git a/authority/provisioner/options_test.go b/authority/provisioner/options_test.go index 8f411aca..0bcf9ec3 100644 --- a/authority/provisioner/options_test.go +++ b/authority/provisioner/options_test.go @@ -287,3 +287,38 @@ func Test_unsafeParseSigned(t *testing.T) { }) } } + +func TestX509Options_IsWildcardLiteralAllowed(t *testing.T) { + tests := []struct { + name string + options *X509Options + want bool + }{ + { + name: "nil-options", + options: nil, + want: true, + }, + { + name: "set-true", + options: &X509Options{ + AllowWildcardNames: true, + }, + want: true, + }, + { + name: "set-false", + options: &X509Options{ + AllowWildcardNames: false, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.options.AreWildcardNamesAllowed(); got != tt.want { + t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/policy.go b/authority/provisioner/policy.go new file mode 100644 index 00000000..95ef4163 --- /dev/null +++ b/authority/provisioner/policy.go @@ -0,0 +1,65 @@ +package provisioner + +import "github.com/smallstep/certificates/authority/policy" + +type policyEngine struct { + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy +} + +func newPolicyEngine(options *Options) (*policyEngine, error) { + + if options == nil { + return nil, nil + } + + var ( + x509Policy policy.X509Policy + sshHostPolicy policy.HostPolicy + sshUserPolicy policy.UserPolicy + err error + ) + + // Initialize the x509 allow/deny policy engine + if x509Policy, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil { + return nil, err + } + + // Initialize the SSH allow/deny policy engine for host certificates + if sshHostPolicy, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil { + return nil, err + } + + // Initialize the SSH allow/deny policy engine for user certificates + if sshUserPolicy, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil { + return nil, err + } + + return &policyEngine{ + x509Policy: x509Policy, + sshHostPolicy: sshHostPolicy, + sshUserPolicy: sshUserPolicy, + }, nil +} + +func (p *policyEngine) getX509() policy.X509Policy { + if p == nil { + return nil + } + return p.x509Policy +} + +func (p *policyEngine) getSSHHost() policy.HostPolicy { + if p == nil { + return nil + } + return p.sshHostPolicy +} + +func (p *policyEngine) getSSHUser() policy.UserPolicy { + if p == nil { + return nil + } + return p.sshUserPolicy +} diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 9dc1edd8..c49c993e 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -28,13 +28,12 @@ type SCEP struct { // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC - EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` - - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - secretChallengePassword string - encryptionAlgorithm int - ctl *Controller + EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` + ctl *Controller + secretChallengePassword string + encryptionAlgorithm int } // GetID returns the provisioner unique identifier. @@ -84,7 +83,6 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration { // Init initializes and validates the fields of a SCEP type. func (s *SCEP) Init(config Config) (err error) { - switch { case s.Type == "": return errors.New("provisioner type cannot be empty") @@ -112,7 +110,7 @@ func (s *SCEP) Init(config Config) (err error) { // TODO: add other, SCEP specific, options? - s.ctl, err = NewController(s, s.Claims, config) + s.ctl, err = NewController(s, s.Claims, config, s.Options) return } @@ -129,6 +127,7 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // validators newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(s.ctl.getPolicy().getX509()), }, nil } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 80dfc66e..2eefd331 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -13,9 +13,11 @@ import ( "reflect" "time" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/errs" ) // DefaultCertValidity is the default validity for a certificate if none is specified. @@ -402,6 +404,32 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { return nil } +// x509NamePolicyValidator validates that the certificate (to be signed) +// contains only allowed SANs. +type x509NamePolicyValidator struct { + policyEngine policy.X509Policy +} + +// newX509NamePolicyValidator return a new SANs allow/deny validator. +func newX509NamePolicyValidator(engine policy.X509Policy) *x509NamePolicyValidator { + return &x509NamePolicyValidator{ + policyEngine: engine, + } +} + +// Valid validates that the certificate (to be signed) contains only allowed SANs. +func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error { + if v.policyEngine == nil { + return nil + } + return v.policyEngine.IsX509CertificateAllowed(cert) +} + +// var ( +// stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} +// stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) +// ) + // type stepProvisionerASN1 struct { // Type int // Name []byte diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index a2ca78b1..70dffba2 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -4,11 +4,13 @@ import ( "crypto/rsa" "encoding/binary" "encoding/json" + "fmt" "math/big" "strings" "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" @@ -444,6 +446,53 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOpti } } +// sshNamePolicyValidator validates that the certificate (to be signed) +// contains only allowed principals. +type sshNamePolicyValidator struct { + hostPolicyEngine policy.HostPolicy + userPolicyEngine policy.UserPolicy +} + +// newSSHNamePolicyValidator return a new SSH allow/deny validator. +func newSSHNamePolicyValidator(host policy.HostPolicy, user policy.UserPolicy) *sshNamePolicyValidator { + return &sshNamePolicyValidator{ + hostPolicyEngine: host, + userPolicyEngine: user, + } +} + +// Valid validates that the certificate (to be signed) contains only allowed principals. +func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { + if v.hostPolicyEngine == nil && v.userPolicyEngine == nil { + // no policy configured at all; allow anything + return nil + } + + // Check the policy type to execute based on type of the certificate. + // We don't allow user certs if only a host policy engine is configured and + // the same for host certs: if only a user policy engine is configured, host + // certs are denied. When both policy engines are configured, the type of + // cert determines which policy engine is used. + switch cert.CertType { + case ssh.HostCert: + // when no host policy engine is configured, but a user policy engine is + // configured, the host certificate is denied. + if v.hostPolicyEngine == nil && v.userPolicyEngine != nil { + return errors.New("SSH host certificate not authorized") + } + return v.hostPolicyEngine.IsSSHCertificateAllowed(cert) + case ssh.UserCert: + // when no user policy engine is configured, but a host policy engine is + // configured, the user certificate is denied. + if v.userPolicyEngine == nil && v.hostPolicyEngine != nil { + return errors.New("SSH user certificate not authorized") + } + return v.userPolicyEngine.IsSSHCertificateAllowed(cert) + default: + return fmt.Errorf("unexpected SSH certificate type %d", cert.CertType) // satisfy return; shouldn't happen + } +} + // sshCertTypeUInt32 func sshCertTypeUInt32(ct string) uint32 { switch ct { diff --git a/authority/provisioner/ssh_options.go b/authority/provisioner/ssh_options.go index 7ee236d1..93633a21 100644 --- a/authority/provisioner/ssh_options.go +++ b/authority/provisioner/ssh_options.go @@ -6,6 +6,8 @@ import ( "github.com/pkg/errors" "go.step.sm/crypto/sshutil" + + "github.com/smallstep/certificates/authority/policy" ) // SSHCertificateOptions is an interface that returns a list of options passed when @@ -33,6 +35,60 @@ type SSHOptions struct { // TemplateData is a JSON object with variables that can be used in custom // templates. TemplateData json.RawMessage `json:"templateData,omitempty"` + + // User contains SSH user certificate options. + User *policy.SSHUserCertificateOptions `json:"-"` + + // Host contains SSH host certificate options. + Host *policy.SSHHostCertificateOptions `json:"-"` +} + +// GetAllowedUserNameOptions returns the SSHNameOptions that are +// allowed when SSH User certificates are requested. +func (o *SSHOptions) GetAllowedUserNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.User == nil { + return nil + } + return o.User.AllowedNames +} + +// GetDeniedUserNameOptions returns the SSHNameOptions that are +// denied when SSH user certificates are requested. +func (o *SSHOptions) GetDeniedUserNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.User == nil { + return nil + } + return o.User.DeniedNames +} + +// GetAllowedHostNameOptions returns the SSHNameOptions that are +// allowed when SSH host certificates are requested. +func (o *SSHOptions) GetAllowedHostNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.Host == nil { + return nil + } + return o.Host.AllowedNames +} + +// GetDeniedHostNameOptions returns the SSHNameOptions that are +// denied when SSH host certificates are requested. +func (o *SSHOptions) GetDeniedHostNameOptions() *policy.SSHNameOptions { + if o == nil { + return nil + } + if o.Host == nil { + return nil + } + return o.Host.DeniedNames } // HasTemplate returns true if a template is defined in the provisioner options. diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 9de0fca2..c3a1a639 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -8,9 +8,11 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/jose" + + "github.com/smallstep/certificates/errs" ) // sshPOPPayload extends jwt.Claims with step attributes. @@ -95,7 +97,7 @@ func (p *SSHPOP) Init(config Config) (err error) { p.sshPubKeys = config.SSHKeys config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, nil) return } diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 3d032ea0..0a1d176c 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -184,7 +184,7 @@ func generateJWK() (*JWK, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -219,7 +219,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -256,7 +256,7 @@ func generateSSHPOP() (*SSHPOP, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -305,7 +305,7 @@ M46l92gdOozT } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -343,7 +343,7 @@ func generateOIDC() (*OIDC, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } @@ -373,7 +373,7 @@ func generateGCP() (*GCP, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences.WithFragment("gcp/" + name), - }) + }, nil) return p, err } @@ -411,7 +411,7 @@ func generateAWS() (*AWS, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences.WithFragment("aws/" + name), - }) + }, nil) return p, err } @@ -518,7 +518,7 @@ func generateAWSV1Only() (*AWS, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences.WithFragment("aws/" + name), - }) + }, nil) return p, err } @@ -608,7 +608,7 @@ func generateAzure() (*Azure, error) { } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, - }) + }, nil) return p, err } diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 295d81fb..69576da5 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -8,10 +8,12 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/errs" ) // x5cPayload extends jwt.Claims with step attributes. @@ -121,7 +123,7 @@ func (p *X5C) Init(config Config) (err error) { } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) - p.ctl, err = NewController(p, p.Claims, config) + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } @@ -233,6 +235,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er defaultSANsValidator(claims.SANs), defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), }, nil } @@ -317,5 +320,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, + // Ensure that all principal names are allowed + newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), ), nil } diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index a3308f00..f28fcc7c 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -468,7 +468,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - assert.Equals(t, len(opts), 8) + assert.Equals(t, 9, len(opts)) for _, o := range opts { switch v := o.(type) { case *X5C: @@ -480,7 +480,6 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.Len(t, 0, v.KeyValuePairs) case profileLimitDuration: assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration()) - claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) assert.FatalError(t, err) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) @@ -492,6 +491,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equals(t, nil, v.policyEngine) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } @@ -788,6 +789,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) case *sshCertValidityValidator: assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) + case *sshNamePolicyValidator: + assert.Equals(t, nil, v.userPolicyEngine) + assert.Equals(t, nil, v.hostPolicyEngine) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) @@ -795,9 +799,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { tot++ } if len(tc.claims.Step.SSH.CertType) > 0 { - assert.Equals(t, tot, 9) + assert.Equals(t, tot, 10) } else { - assert.Equals(t, tot, 7) + assert.Equals(t, tot, 8) } } } diff --git a/authority/provisioners.go b/authority/provisioners.go index 63fb630b..cde4a6e9 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -10,16 +10,19 @@ import ( "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/authority/config" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" + "gopkg.in/square/go-jose.v2/jwt" + "go.step.sm/cli-utils/step" "go.step.sm/cli-utils/ui" "go.step.sm/crypto/jose" "go.step.sm/linkedca" - "gopkg.in/square/go-jose.v2/jwt" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" ) // GetEncryptedKey returns the JWE key corresponding to the given kid argument. @@ -170,6 +173,12 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi return admin.WrapErrorISE(err, "error generating provisioner config") } + adm := linkedca.MustAdminFromContext(ctx) + + if err := a.checkProvisionerPolicy(ctx, adm, prov.Name, prov.Policy); err != nil { + return err + } + if err := certProv.Init(provisionerConfig); err != nil { return admin.WrapError(admin.ErrorBadRequestType, err, "error validating configuration for provisioner %s", prov.Name) } @@ -215,6 +224,12 @@ func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisio return admin.WrapErrorISE(err, "error generating provisioner config") } + adm := linkedca.MustAdminFromContext(ctx) + + if err := a.checkProvisionerPolicy(ctx, adm, nu.Name, nu.Policy); err != nil { + return err + } + if err := certProv.Init(provisionerConfig); err != nil { return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name) } @@ -427,6 +442,60 @@ func optionsToCertificates(p *linkedca.Provisioner) *provisioner.Options { ops.SSH.Template = string(p.SshTemplate.Template) ops.SSH.TemplateData = p.SshTemplate.Data } + if pol := p.GetPolicy(); pol != nil { + if x := pol.GetX509(); x != nil { + if allow := x.GetAllow(); allow != nil { + ops.X509.AllowedNames = &policy.X509NameOptions{ + DNSDomains: allow.Dns, + IPRanges: allow.Ips, + EmailAddresses: allow.Emails, + URIDomains: allow.Uris, + } + } + if deny := x.GetDeny(); deny != nil { + ops.X509.DeniedNames = &policy.X509NameOptions{ + DNSDomains: deny.Dns, + IPRanges: deny.Ips, + EmailAddresses: deny.Emails, + URIDomains: deny.Uris, + } + } + } + if ssh := pol.GetSsh(); ssh != nil { + if host := ssh.GetHost(); host != nil { + ops.SSH.Host = &policy.SSHHostCertificateOptions{} + if allow := host.GetAllow(); allow != nil { + ops.SSH.Host.AllowedNames = &policy.SSHNameOptions{ + DNSDomains: allow.Dns, + IPRanges: allow.Ips, + Principals: allow.Principals, + } + } + if deny := host.GetDeny(); deny != nil { + ops.SSH.Host.DeniedNames = &policy.SSHNameOptions{ + DNSDomains: deny.Dns, + IPRanges: deny.Ips, + Principals: deny.Principals, + } + } + } + if user := ssh.GetUser(); user != nil { + ops.SSH.User = &policy.SSHUserCertificateOptions{} + if allow := user.GetAllow(); allow != nil { + ops.SSH.User.AllowedNames = &policy.SSHNameOptions{ + EmailAddresses: allow.Emails, + Principals: allow.Principals, + } + } + if deny := user.GetDeny(); deny != nil { + ops.SSH.User.DeniedNames = &policy.SSHNameOptions{ + EmailAddresses: deny.Emails, + Principals: deny.Principals, + } + } + } + } + } return ops } diff --git a/authority/ssh.go b/authority/ssh.go index 4a67b28c..0521ab58 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -5,18 +5,23 @@ import ( "crypto/rand" "crypto/x509" "encoding/binary" + "errors" + "fmt" "net/http" "strings" "time" + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/randutil" + "go.step.sm/crypto/sshutil" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" + policy "github.com/smallstep/certificates/policy" "github.com/smallstep/certificates/templates" - "go.step.sm/crypto/randutil" - "go.step.sm/crypto/sshutil" - "golang.org/x/crypto/ssh" ) const ( @@ -241,6 +246,23 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType) } + // Check if authority is allowed to sign the certificate + if err := a.isAllowedToSignSSHCertificate(certTpl); err != nil { + var pe *policy.NamePolicyError + if errors.As(err, &pe) && pe.Reason == policy.NotAllowed { + return nil, &errs.Error{ + // NOTE: custom forbidden error, so that denied name is sent to client + // as well as shown in the logs. + Status: http.StatusForbidden, + Err: fmt.Errorf("authority not allowed to sign: %w", err), + Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()), + } + } + return nil, errs.InternalServerErr(err, + errs.WithMessage("authority.SignSSH: error creating ssh certificate"), + ) + } + // Sign certificate. cert, err := sshutil.CreateCertificate(certTpl, signer) if err != nil { @@ -261,6 +283,11 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi return cert, nil } +// isAllowedToSignSSHCertificate checks if the Authority is allowed to sign the SSH certificate. +func (a *Authority) isAllowedToSignSSHCertificate(cert *ssh.Certificate) error { + return a.policyEngine.IsSSHCertificateAllowed(cert) +} + // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { diff --git a/authority/ssh_test.go b/authority/ssh_test.go index ce840fe1..4fd7eaa0 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -20,6 +20,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/templates" @@ -159,6 +160,14 @@ func TestAuthority_SignSSH(t *testing.T) { assert.FatalError(t, err) hostTemplateWithHosts, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"foo.test.com", "bar.test.com"})) assert.FatalError(t, err) + userTemplateWithRoot, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"root"})) + assert.FatalError(t, err) + hostTemplateWithExampleDotCom, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"example.com"})) + assert.FatalError(t, err) + badUserTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"127.0.0.1"})) + assert.FatalError(t, err) + badHostTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"host...local"})) + assert.FatalError(t, err) userCustomTemplate, err := provisioner.TemplateSSHOptions(&provisioner.Options{ SSH: &provisioner.SSHOptions{Template: `{ "type": "{{ .Type }}", @@ -182,11 +191,36 @@ func TestAuthority_SignSSH(t *testing.T) { }, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"})) assert.FatalError(t, err) + userPolicyOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + User: &policy.SSHUserCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + Principals: []string{"user"}, + }, + }, + }, + } + userPolicy, err := policy.New(userPolicyOptions) + assert.FatalError(t, err) + + hostPolicyOptions := &policy.Options{ + SSH: &policy.SSHPolicyOptions{ + Host: &policy.SSHHostCertificateOptions{ + AllowedNames: &policy.SSHNameOptions{ + DNSDomains: []string{"*.test.com"}, + }, + }, + }, + } + hostPolicy, err := policy.New(hostPolicyOptions) + assert.FatalError(t, err) + now := time.Now() type fields struct { sshCAUserCertSignKey ssh.Signer sshCAHostCertSignKey ssh.Signer + policyEngine *policy.Engine } type args struct { key ssh.PublicKey @@ -206,39 +240,48 @@ func TestAuthority_SignSSH(t *testing.T) { want want wantErr bool }{ - {"ok-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, - {"ok-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, - {"ok-user-only", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, - {"ok-host-only", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, - {"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false}, - {"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false}, - {"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, - {"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, - {"ok-opts-valid-after", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false}, - {"ok-opts-valid-before", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false}, - {"ok-cert-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false}, - {"ok-cert-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false}, - {"ok-opts-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false}, - {"ok-opts-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false}, - {"ok-custom-template", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userCustomTemplate, userOptions}}, want{CertType: ssh.UserCert, Principals: []string{"user", "admin"}}, false}, - {"fail-opts-type", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "foo"}, []provisioner.SignOption{userTemplate}}, want{}, true}, - {"fail-cert-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("an error")}}, want{}, true}, - {"fail-cert-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("an error")}}, want{}, true}, - {"fail-opts-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("an error")}}, want{}, true}, - {"fail-opts-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("an error")}}, want{}, true}, - {"fail-bad-sign-options", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, "wrong type"}}, want{}, true}, - {"fail-no-user-key", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{}, true}, - {"fail-no-host-key", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{}, true}, - {"fail-bad-type", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, sshTestModifier{CertType: 100}}}, want{}, true}, - {"fail-custom-template", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userFailTemplate, userOptions}}, want{}, true}, - {"fail-custom-template-syntax-error-file", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONSyntaxErrorTemplateFile, userOptions}}, want{}, true}, - {"fail-custom-template-syntax-value-file", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONValueErrorTemplateFile, userOptions}}, want{}, true}, + {"ok-user", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, + {"ok-host", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, + {"ok-user-only", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, + {"ok-host-only", fields{nil, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, + {"ok-opts-type-user", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-type-host", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false}, + {"ok-opts-principals", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, + {"ok-opts-principals", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, + {"ok-opts-valid-after", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false}, + {"ok-opts-valid-before", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false}, + {"ok-cert-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-cert-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-opts-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false}, + {"ok-custom-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userCustomTemplate, userOptions}}, want{CertType: ssh.UserCert, Principals: []string{"user", "admin"}}, false}, + {"ok-user-policy", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, + {"ok-host-policy", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, + {"fail-opts-type", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "foo"}, []provisioner.SignOption{userTemplate}}, want{}, true}, + {"fail-cert-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("an error")}}, want{}, true}, + {"fail-cert-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("an error")}}, want{}, true}, + {"fail-opts-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("an error")}}, want{}, true}, + {"fail-opts-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("an error")}}, want{}, true}, + {"fail-bad-sign-options", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, "wrong type"}}, want{}, true}, + {"fail-no-user-key", fields{nil, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{}, true}, + {"fail-no-host-key", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{}, true}, + {"fail-bad-type", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, sshTestModifier{CertType: 100}}}, want{}, true}, + {"fail-custom-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userFailTemplate, userOptions}}, want{}, true}, + {"fail-custom-template-syntax-error-file", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONSyntaxErrorTemplateFile, userOptions}}, want{}, true}, + {"fail-custom-template-syntax-value-file", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONValueErrorTemplateFile, userOptions}}, want{}, true}, + {"fail-user-policy", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"root"}}, []provisioner.SignOption{userTemplateWithRoot}}, want{}, true}, + {"fail-user-policy-with-host-cert", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com"}}, []provisioner.SignOption{hostTemplateWithExampleDotCom}}, want{}, true}, + {"fail-user-policy-with-bad-user", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{badUserTemplate}}, want{}, true}, + {"fail-host-policy", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"example.com"}}, []provisioner.SignOption{hostTemplateWithExampleDotCom}}, want{}, true}, + {"fail-host-policy-with-user-cert", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{}, true}, + {"fail-host-policy-with-bad-host", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"example.com"}}, []provisioner.SignOption{badHostTemplate}}, want{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey + a.policyEngine = tt.fields.policyEngine got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...) if (err != nil) != tt.wantErr { diff --git a/authority/tls.go b/authority/tls.go index 631f7937..d23b0da7 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -16,16 +16,19 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/x509util" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" - "go.step.sm/crypto/pemutil" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" + "github.com/smallstep/certificates/policy" ) // GetTLSOptions returns the tls options configured. @@ -196,6 +199,25 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } } + // Check if authority is allowed to sign the certificate + if err := a.isAllowedToSignX509Certificate(leaf); err != nil { + var pe *policy.NamePolicyError + if errors.As(err, &pe) && pe.Reason == policy.NotAllowed { + return nil, errs.ApplyOptions(&errs.Error{ + // NOTE: custom forbidden error, so that denied name is sent to client + // as well as shown in the logs. + Status: http.StatusForbidden, + Err: fmt.Errorf("authority not allowed to sign: %w", err), + Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()), + }, opts...) + } + return nil, errs.InternalServerErr(err, + errs.WithKeyVal("csr", csr), + errs.WithKeyVal("signOptions", signOpts), + errs.WithMessage("error creating certificate"), + ) + } + // Sign certificate lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate)) resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{ @@ -219,6 +241,18 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign return fullchain, nil } +// isAllowedToSignX509Certificate checks if the Authority is allowed +// to sign the X.509 certificate. +func (a *Authority) isAllowedToSignX509Certificate(cert *x509.Certificate) error { + return a.policyEngine.IsX509CertificateAllowed(cert) +} + +// AreSANsAllowed evaluates the provided sans against the +// authority X.509 policy. +func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error { + return a.policyEngine.AreSANsAllowed(sans) +} + // Renew creates a new Certificate identical to the old certificate, except // with a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { diff --git a/authority/tls_test.go b/authority/tls_test.go index e199e0c5..9330f0a3 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -27,6 +27,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas/softcas" "github.com/smallstep/certificates/db" @@ -511,6 +512,39 @@ ZYtQ9Ot36qc= code: http.StatusForbidden, } }, + "fail with policy": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + aa := testAuthority(t) + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + aa.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + fmt.Println(crt.Subject) + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + return nil + }, + } + options := &policy.Options{ + X509: &policy.X509PolicyOptions{ + DeniedNames: &policy.X509NameOptions{ + DNSDomains: []string{"test.smallstep.com"}, + }, + }, + } + engine, err := policy.New(options) + assert.FatalError(t, err) + aa.policyEngine = engine + return &signTest{ + auth: aa, + csr: csr, + extraOpts: extraOpts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 6, + err: errors.New("authority not allowed to sign"), + code: http.StatusForbidden, + } + }, "ok": func(t *testing.T) *signTest { csr := getCSR(t, priv) _a := testAuthority(t) @@ -653,6 +687,38 @@ ZYtQ9Ot36qc= extensionsCount: 7, } }, + "ok with policy": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + aa := testAuthority(t) + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + aa.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + fmt.Println(crt.Subject) + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + return nil + }, + } + options := &policy.Options{ + X509: &policy.X509PolicyOptions{ + AllowedNames: &policy.X509NameOptions{ + CommonNames: []string{"smallstep test"}, + DNSDomains: []string{"*.smallstep.com"}, + }, + }, + } + engine, err := policy.New(options) + assert.FatalError(t, err) + aa.policyEngine = engine + return &signTest{ + auth: aa, + csr: csr, + extraOpts: extraOpts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 6, + } + }, } for name, genTestCase := range tests { diff --git a/ca/adminClient.go b/ca/adminClient.go index 72f62dd8..bf853e9d 100644 --- a/ca/adminClient.go +++ b/ca/adminClient.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/x509" "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -12,15 +13,17 @@ import ( "time" "github.com/pkg/errors" - adminAPI "github.com/smallstep/certificates/authority/admin/api" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/errs" + "google.golang.org/protobuf/encoding/protojson" + "go.step.sm/cli-utils/token" "go.step.sm/cli-utils/token/provision" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" "go.step.sm/linkedca" - "google.golang.org/protobuf/encoding/protojson" + + adminAPI "github.com/smallstep/certificates/authority/admin/api" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" ) const ( @@ -687,6 +690,418 @@ retry: return nil } +func (c *AdminClient) GetAuthorityPolicy() (*linkedca.Policy, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) + if err != nil { + return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client GET %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) CreateAuthorityPolicy(p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client POST %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) UpdateAuthorityPolicy(p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client PUT %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) RemoveAuthorityPolicy() error { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) + if err != nil { + return fmt.Errorf("creating DELETE %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return fmt.Errorf("client DELETE %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + +func (c *AdminClient) GetProvisionerPolicy(provisionerName string) (*linkedca.Policy, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) + if err != nil { + return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client GET %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) CreateProvisionerPolicy(provisionerName string, p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client POST %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) UpdateProvisionerPolicy(provisionerName string, p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client PUT %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) RemoveProvisionerPolicy(provisionerName string) error { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) + tok, err := c.generateAdminToken(u) + if err != nil { + return fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) + if err != nil { + return fmt.Errorf("creating DELETE %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return fmt.Errorf("client DELETE %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + +func (c *AdminClient) GetACMEPolicy(provisionerName, reference, keyID string) (*linkedca.Policy, error) { + var retried bool + var urlPath string + switch { + case keyID != "": + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) + default: + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) + } + u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) + if err != nil { + return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client GET %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) CreateACMEPolicy(provisionerName, reference, keyID string, p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + var urlPath string + switch { + case keyID != "": + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) + default: + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) + } + u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client POST %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) UpdateACMEPolicy(provisionerName, reference, keyID string, p *linkedca.Policy) (*linkedca.Policy, error) { + var retried bool + body, err := protojson.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %w", err) + } + var urlPath string + switch { + case keyID != "": + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) + default: + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) + } + u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) + tok, err := c.generateAdminToken(u) + if err != nil { + return nil, fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("client PUT %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var policy = new(linkedca.Policy) + if err := readProtoJSON(resp.Body, policy); err != nil { + return nil, fmt.Errorf("error reading %s: %w", u, err) + } + return policy, nil +} + +func (c *AdminClient) RemoveACMEPolicy(provisionerName, reference, keyID string) error { + var retried bool + var urlPath string + switch { + case keyID != "": + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) + default: + urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) + } + u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) + tok, err := c.generateAdminToken(u) + if err != nil { + return fmt.Errorf("error generating admin token: %w", err) + } + req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) + if err != nil { + return fmt.Errorf("creating DELETE %s request failed: %w", u, err) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return fmt.Errorf("client DELETE %s failed: %w", u, err) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + func readAdminError(r io.ReadCloser) error { // TODO: not all errors can be read (i.e. 404); seems to be a bigger issue defer r.Close() diff --git a/ca/ca.go b/ca/ca.go index 0380d166..3cb4646b 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -219,7 +219,8 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { adminDB := auth.GetAdminDatabase() if adminDB != nil { acmeAdminResponder := adminAPI.NewACMEAdminResponder() - adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder) + policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB, acmeDB) + adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder, policyAdminResponder) mux.Route("/admin", func(r chi.Router) { adminHandler.Route(r) }) diff --git a/go.mod b/go.mod index 763a96de..36e39b4d 100644 --- a/go.mod +++ b/go.mod @@ -14,20 +14,26 @@ require ( github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect github.com/Masterminds/sprig/v3 v3.2.2 github.com/ThalesIgnite/crypto11 v1.2.4 - github.com/aws/aws-sdk-go v1.30.29 + github.com/aws/aws-sdk-go v1.37.0 github.com/dgraph-io/ristretto v0.0.4-0.20200906165740-41ebdbffecfd // indirect + github.com/fatih/color v1.9.0 // indirect + github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect github.com/go-chi/chi v4.0.2+incompatible github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 + github.com/go-sql-driver/mysql v1.6.0 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.7 github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.1.1 github.com/hashicorp/vault/api v1.3.1 github.com/hashicorp/vault/api/auth/approle v0.1.1 + github.com/jhump/protoreflect v1.9.0 // indirect github.com/mattn/go-colorable v0.1.8 // indirect github.com/mattn/go-isatty v0.0.13 // indirect github.com/micromdm/scep/v2 v2.1.0 + github.com/miekg/pkcs11 v1.0.3 // indirect github.com/newrelic/go-agent v2.15.0+incompatible github.com/pkg/errors v0.9.1 github.com/rs/xid v1.2.1 @@ -37,17 +43,21 @@ require ( github.com/smallstep/nosql v0.4.0 github.com/stretchr/testify v1.7.1 github.com/urfave/cli v1.22.4 + go.etcd.io/bbolt v1.3.6 // indirect go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 go.step.sm/crypto v0.16.1 - go.step.sm/linkedca v0.15.0 + go.step.sm/linkedca v0.16.1 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 - golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd + golang.org/x/net v0.0.0-20220403103023-749bd193bc2b + golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 // indirect + golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect google.golang.org/api v0.70.0 - google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf - google.golang.org/grpc v1.44.0 - google.golang.org/protobuf v1.27.1 + google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de + google.golang.org/grpc v1.45.0 + google.golang.org/protobuf v1.28.0 gopkg.in/square/go-jose.v2 v2.6.0 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) // replace github.com/smallstep/nosql => ../nosql diff --git a/go.sum b/go.sum index b8c908ef..55528b4f 100644 --- a/go.sum +++ b/go.sum @@ -143,8 +143,8 @@ github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6l github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.30.29 h1:NXNqBS9hjOCpDL8SyCyl38gZX3LLLunKOJc5E7vJ8P0= -github.com/aws/aws-sdk-go v1.30.29/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go v1.37.0 h1:GzFnhOIsrGyQ69s7VgqtrG2BG8v7X7vwB3Xpbd/DBBk= +github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -235,13 +235,15 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.m github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch/v5 v5.5.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4= -github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= +github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= +github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c= +github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/frankban/quicktest v1.10.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y= @@ -269,8 +271,9 @@ github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG github.com/go-piv/piv-go v1.7.0 h1:rfjdFdASfGV5KLJhSjgpGJ5lzVZVtRWn8ovy/H9HQ/U= github.com/go-piv/piv-go v1.7.0/go.mod h1:ON2WvQncm7dIkCQ7kYJs+nc3V4jHGfrrJnSF8HKy7Gk= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -287,8 +290,9 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -369,6 +373,7 @@ github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pf github.com/googleapis/gax-go/v2 v2.1.1 h1:dp3bWCh+PPO1zjRRiCSczJav13sBvG4UhNyVTa1KqdU= github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.4.0/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= @@ -511,11 +516,14 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= -github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= github.com/jhump/protoreflect v1.6.0/go.mod h1:eaTn3RZAmMBcV0fifFvlm6VHNz3wSkYyXYWUh7ymB74= +github.com/jhump/protoreflect v1.9.0 h1:npqHz788dryJiR/l6K/RUQAyh2SwV91+d1dnh4RjO9w= +github.com/jhump/protoreflect v1.9.0/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= -github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -585,8 +593,9 @@ github.com/micromdm/scep/v2 v2.1.0 h1:2fS9Rla7qRR266hvUoEauBJ7J6FhgssEiq2OkSKXma github.com/micromdm/scep/v2 v2.1.0/go.mod h1:BkF7TkPPhmgJAMtHfP+sFTKXmgzNJgLQlvvGoOExBcc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4= -github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f h1:eVB9ELsoq5ouItQBr5Tj334bhPJG/MX+m7rTchmzVUQ= github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/miekg/pkcs11 v1.0.3 h1:iMwmD7I5225wv84WxIG/bmxz9AXjWvTWIbM/TYHvWtw= +github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= @@ -624,6 +633,7 @@ github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OS github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f/go.mod h1:nwPd6pDNId/Xi16qtKrFHrauSwMNuvk+zcjk89wrnlA= github.com/newrelic/go-agent v2.15.0+incompatible h1:IB0Fy+dClpBq9aEoIrLyQXzU34JyI1xVTanPLB/+jvU= github.com/newrelic/go-agent v2.15.0+incompatible/go.mod h1:a8Fv1b/fYhFSReoTU6HDkTYIMZeSVNffmoS726Y0LzQ= +github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso= github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= @@ -782,8 +792,9 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1 github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= +go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= +go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= go.mozilla.org/pkcs7 v0.0.0-20210730143726-725912489c62/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 h1:CCriYyAfq1Br1aIYettdHZTy8mBTIPo7We18TuO/bak= @@ -804,8 +815,10 @@ go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/ go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.16.1 h1:4mnZk21cSxyMGxsEpJwZKKvJvDu1PN09UVrWWFNUBdk= go.step.sm/crypto v0.16.1/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= -go.step.sm/linkedca v0.15.0 h1:lEkGRDY+u7FudGKt8yEo7nBy5OzceO9s3rl+/sZVL5M= -go.step.sm/linkedca v0.15.0/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM= +go.step.sm/linkedca v0.16.0 h1:9xdE150lRTEoBq1gJl+prePpSmNqXRXsez3qzRs3Lic= +go.step.sm/linkedca v0.16.0/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM= +go.step.sm/linkedca v0.16.1 h1:CdbMV5SjnlRsgeYTXaaZmQCkYIgJq8BOzpewri57M2k= +go.step.sm/linkedca v0.16.1/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -927,8 +940,9 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211020060615-d418f374d309/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220403103023-749bd193bc2b h1:vI32FkLJNAWtGD4BwkThwEy6XS7ZLLMHkSkYfF8M0W0= +golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1007,6 +1021,7 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1041,8 +1056,9 @@ golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220209214540-3681064d5158 h1:rm+CHSpPEEW2IsXUib1ThaHIjuBVZjxNgSKmBLFfD4c= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4= +golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= @@ -1062,8 +1078,9 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI= golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1108,8 +1125,10 @@ golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWc golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200717024301-6ddee64345a6/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= @@ -1244,8 +1263,9 @@ google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ6 google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220207164111-0872dc986b00/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220218161850-94dd64e39d7c/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI= -google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf h1:SVYXkUz2yZS9FWb2Gm8ivSlbNQzL2Z/NpPKE3RG2jWk= google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI= +google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de h1:9Ti5SG2U4cAcluryUo/sFay3TQKoxiFMfaT0pbizU7k= +google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo= google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= @@ -1279,8 +1299,9 @@ google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnD google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.40.1/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k= -google.golang.org/grpc v1.44.0 h1:weqSxi/TMs1SqFRMHCtBgXRs8k3X39QIDEZ0pRcttUg= google.golang.org/grpc v1.44.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= +google.golang.org/grpc v1.45.0 h1:NEpgUqV3Z+ZjkqMsxMg11IaDrXY4RY6CQukSGK0uI1M= +google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -1292,10 +1313,12 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -1320,11 +1343,13 @@ gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/policy/engine.go b/policy/engine.go new file mode 100755 index 00000000..d1fb4928 --- /dev/null +++ b/policy/engine.go @@ -0,0 +1,301 @@ +package policy + +import ( + "crypto/x509" + "fmt" + "net" + "net/url" + + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/x509util" +) + +type NamePolicyReason int + +const ( + // NotAllowed results when an instance of NamePolicyEngine + // determines that there's a constraint which doesn't permit + // a DNS or another type of SAN to be signed (or otherwise used). + NotAllowed NamePolicyReason = iota + 1 + // CannotParseDomain is returned when an error occurs + // when parsing the domain part of SAN or subject. + CannotParseDomain + // CannotParseRFC822Name is returned when an error + // occurs when parsing an email address. + CannotParseRFC822Name + // CannotMatch is the type of error returned when + // an error happens when matching SAN types. + CannotMatchNameToConstraint +) + +type NameType string + +const ( + CNNameType NameType = "cn" + DNSNameType NameType = "dns" + IPNameType NameType = "ip" + EmailNameType NameType = "email" + URINameType NameType = "uri" + PrincipalNameType NameType = "principal" +) + +type NamePolicyError struct { + Reason NamePolicyReason + NameType NameType + Name string + detail string +} + +func (e *NamePolicyError) Error() string { + switch e.Reason { + case NotAllowed: + return fmt.Sprintf("%s name %q not allowed", e.NameType, e.Name) + case CannotParseDomain: + return fmt.Sprintf("cannot parse %s domain %q", e.NameType, e.Name) + case CannotParseRFC822Name: + return fmt.Sprintf("cannot parse %s rfc822Name %q", e.NameType, e.Name) + case CannotMatchNameToConstraint: + return fmt.Sprintf("error matching %s name %q to constraint", e.NameType, e.Name) + default: + return fmt.Sprintf("unknown error reason (%d): %s", e.Reason, e.detail) + } +} + +func (e *NamePolicyError) Detail() string { + return e.detail +} + +// NamePolicyEngine can be used to check that a CSR or Certificate meets all allowed and +// denied names before a CA creates and/or signs the Certificate. +// TODO(hs): the X509 RFC also defines name checks on directory name; support that? +// TODO(hs): implement Stringer interface: describe the contents of the NamePolicyEngine? +// TODO(hs): implement matching URI schemes, paths, etc; not just the domain part of URI domains + +type NamePolicyEngine struct { + + // verifySubjectCommonName is set when Subject Common Name must be verified + verifySubjectCommonName bool + // allowLiteralWildcardNames allows literal wildcard DNS domains + allowLiteralWildcardNames bool + + // permitted and exluded constraints similar to x509 Name Constraints + permittedCommonNames []string + excludedCommonNames []string + permittedDNSDomains []string + excludedDNSDomains []string + permittedIPRanges []*net.IPNet + excludedIPRanges []*net.IPNet + permittedEmailAddresses []string + excludedEmailAddresses []string + permittedURIDomains []string + excludedURIDomains []string + permittedPrincipals []string + excludedPrincipals []string + + // some internal counts for housekeeping + numberOfCommonNameConstraints int + numberOfDNSDomainConstraints int + numberOfIPRangeConstraints int + numberOfEmailAddressConstraints int + numberOfURIDomainConstraints int + numberOfPrincipalConstraints int + totalNumberOfPermittedConstraints int + totalNumberOfExcludedConstraints int + totalNumberOfConstraints int +} + +// NewNamePolicyEngine creates a new NamePolicyEngine with NamePolicyOptions +func New(opts ...NamePolicyOption) (*NamePolicyEngine, error) { + + e := &NamePolicyEngine{} + for _, option := range opts { + if err := option(e); err != nil { + return nil, err + } + } + + e.permittedCommonNames = removeDuplicates(e.permittedCommonNames) + e.permittedDNSDomains = removeDuplicates(e.permittedDNSDomains) + e.permittedIPRanges = removeDuplicateIPNets(e.permittedIPRanges) + e.permittedEmailAddresses = removeDuplicates(e.permittedEmailAddresses) + e.permittedURIDomains = removeDuplicates(e.permittedURIDomains) + e.permittedPrincipals = removeDuplicates(e.permittedPrincipals) + + e.excludedCommonNames = removeDuplicates(e.excludedCommonNames) + e.excludedDNSDomains = removeDuplicates(e.excludedDNSDomains) + e.excludedIPRanges = removeDuplicateIPNets(e.excludedIPRanges) + e.excludedEmailAddresses = removeDuplicates(e.excludedEmailAddresses) + e.excludedURIDomains = removeDuplicates(e.excludedURIDomains) + e.excludedPrincipals = removeDuplicates(e.excludedPrincipals) + + e.numberOfCommonNameConstraints = len(e.permittedCommonNames) + len(e.excludedCommonNames) + e.numberOfDNSDomainConstraints = len(e.permittedDNSDomains) + len(e.excludedDNSDomains) + e.numberOfIPRangeConstraints = len(e.permittedIPRanges) + len(e.excludedIPRanges) + e.numberOfEmailAddressConstraints = len(e.permittedEmailAddresses) + len(e.excludedEmailAddresses) + e.numberOfURIDomainConstraints = len(e.permittedURIDomains) + len(e.excludedURIDomains) + e.numberOfPrincipalConstraints = len(e.permittedPrincipals) + len(e.excludedPrincipals) + + e.totalNumberOfPermittedConstraints = len(e.permittedCommonNames) + len(e.permittedDNSDomains) + + len(e.permittedIPRanges) + len(e.permittedEmailAddresses) + len(e.permittedURIDomains) + + len(e.permittedPrincipals) + + e.totalNumberOfExcludedConstraints = len(e.excludedCommonNames) + len(e.excludedDNSDomains) + + len(e.excludedIPRanges) + len(e.excludedEmailAddresses) + len(e.excludedURIDomains) + + len(e.excludedPrincipals) + + e.totalNumberOfConstraints = e.totalNumberOfPermittedConstraints + e.totalNumberOfExcludedConstraints + + return e, nil +} + +// removeDuplicates returns a new slice of strings with +// duplicate values removed. It retains the order of elements +// in the source slice. +func removeDuplicates(items []string) (ret []string) { + + // no need to remove dupes; return original + if len(items) <= 1 { + return items + } + + keys := make(map[string]struct{}, len(items)) + + ret = make([]string, 0, len(items)) + for _, item := range items { + if _, ok := keys[item]; ok { + continue + } + + keys[item] = struct{}{} + ret = append(ret, item) + } + + return +} + +// removeDuplicateIPNets returns a new slice of net.IPNets with +// duplicate values removed. It retains the order of elements in +// the source slice. An IPNet is considered duplicate if its CIDR +// notation exists multiple times in the slice. +func removeDuplicateIPNets(items []*net.IPNet) (ret []*net.IPNet) { + + // no need to remove dupes; return original + if len(items) <= 1 { + return items + } + + keys := make(map[string]struct{}, len(items)) + + ret = make([]*net.IPNet, 0, len(items)) + for _, item := range items { + key := item.String() // use CIDR notation as key + if _, ok := keys[key]; ok { + continue + } + + keys[key] = struct{}{} + ret = append(ret, item) + } + + // TODO(hs): implement filter of fully overlapping ranges, + // so that the smaller ones are automatically removed? + + return +} + +// IsX509CertificateAllowed verifies that all SANs in a Certificate are allowed. +func (e *NamePolicyEngine) IsX509CertificateAllowed(cert *x509.Certificate) error { + if err := e.validateNames(cert.DNSNames, cert.IPAddresses, cert.EmailAddresses, cert.URIs, []string{}); err != nil { + return err + } + + if e.verifySubjectCommonName { + return e.validateCommonName(cert.Subject.CommonName) + } + + return nil +} + +// IsX509CertificateRequestAllowed verifies that all names in the CSR are allowed. +func (e *NamePolicyEngine) IsX509CertificateRequestAllowed(csr *x509.CertificateRequest) error { + if err := e.validateNames(csr.DNSNames, csr.IPAddresses, csr.EmailAddresses, csr.URIs, []string{}); err != nil { + return err + } + + if e.verifySubjectCommonName { + return e.validateCommonName(csr.Subject.CommonName) + } + + return nil +} + +// AreSANSAllowed verifies that all names in the slice of SANs are allowed. +// The SANs are first split into DNS names, IPs, email addresses and URIs. +func (e *NamePolicyEngine) AreSANsAllowed(sans []string) error { + dnsNames, ips, emails, uris := x509util.SplitSANs(sans) + if err := e.validateNames(dnsNames, ips, emails, uris, []string{}); err != nil { + return err + } + return nil +} + +// IsDNSAllowed verifies a single DNS domain is allowed. +func (e *NamePolicyEngine) IsDNSAllowed(dns string) error { + if err := e.validateNames([]string{dns}, []net.IP{}, []string{}, []*url.URL{}, []string{}); err != nil { + return err + } + return nil +} + +// IsIPAllowed verifies a single IP domain is allowed. +func (e *NamePolicyEngine) IsIPAllowed(ip net.IP) error { + if err := e.validateNames([]string{}, []net.IP{ip}, []string{}, []*url.URL{}, []string{}); err != nil { + return err + } + return nil +} + +// IsSSHCertificateAllowed verifies that all principals in an SSH certificate are allowed. +func (e *NamePolicyEngine) IsSSHCertificateAllowed(cert *ssh.Certificate) error { + dnsNames, ips, emails, principals, err := splitSSHPrincipals(cert) + if err != nil { + return err + } + if err := e.validateNames(dnsNames, ips, emails, []*url.URL{}, principals); err != nil { + return err + } + return nil +} + +// splitPrincipals splits SSH certificate principals into DNS names, emails and usernames. +func splitSSHPrincipals(cert *ssh.Certificate) (dnsNames []string, ips []net.IP, emails, principals []string, err error) { + dnsNames = []string{} + ips = []net.IP{} + emails = []string{} + principals = []string{} + var uris []*url.URL + switch cert.CertType { + case ssh.HostCert: + dnsNames, ips, emails, uris = x509util.SplitSANs(cert.ValidPrincipals) + if len(uris) > 0 { + err = fmt.Errorf("URL principals %v not expected in SSH host certificate ", uris) + } + case ssh.UserCert: + // re-using SplitSANs results in anything that can't be parsed as an IP, URI or email + // to be considered a username principal. This allows usernames like h.slatman to be present + // in the SSH certificate. We're exluding URIs, because they can be confusing + // when used in a SSH user certificate. + principals, ips, emails, uris = x509util.SplitSANs(cert.ValidPrincipals) + if len(ips) > 0 { + err = fmt.Errorf("IP principals %v not expected in SSH user certificate ", ips) + } + if len(uris) > 0 { + err = fmt.Errorf("URL principals %v not expected in SSH user certificate ", uris) + } + default: + err = fmt.Errorf("unexpected SSH certificate type %d", cert.CertType) + } + + return +} diff --git a/policy/engine_test.go b/policy/engine_test.go new file mode 100755 index 00000000..1280d14d --- /dev/null +++ b/policy/engine_test.go @@ -0,0 +1,3401 @@ +package policy + +import ( + "crypto/x509" + "crypto/x509/pkix" + "errors" + "net" + "net/url" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" +) + +// TODO(hs): the functionality in the policy engine is a nice candidate for trying fuzzing on +// TODO(hs): more complex test use cases that combine multiple names and permitted/excluded entries? + +func TestNamePolicyEngine_matchDomainConstraint(t *testing.T) { + tests := []struct { + name string + allowLiteralWildcardNames bool + domain string + constraint string + want bool + wantErr bool + }{ + { + name: "fail/wildcard", + domain: "host.local", + constraint: ".example.com", // internally we're using the x509 period prefix as the indicator for exactly one subdomain + want: false, + wantErr: false, + }, + { + name: "fail/wildcard-literal", + domain: "*.example.com", + constraint: ".example.com", // internally we're using the x509 period prefix as the indicator for exactly one subdomain + want: false, + wantErr: false, + }, + { + name: "fail/specific-domain", + domain: "www.example.com", + constraint: "host.example.com", + want: false, + wantErr: false, + }, + { + name: "fail/single-whitespace-domain", + domain: " ", + constraint: "host.example.com", + want: false, + wantErr: false, + }, + { + name: "fail/period-domain", + domain: ".host.example.com", + constraint: ".example.com", + want: false, + wantErr: false, + }, + { + name: "fail/wrong-asterisk-prefix", + domain: "*Xexample.com", + constraint: ".example.com", + want: false, + wantErr: false, + }, + { + name: "fail/asterisk-in-domain", + domain: "e*ample.com", + constraint: ".com", + want: false, + wantErr: false, + }, + { + name: "fail/asterisk-label", + domain: "example.*.local", + constraint: ".local", + want: false, + wantErr: false, + }, + { + name: "fail/multiple-periods", + domain: "example.local", + constraint: "..local", + want: false, + wantErr: false, + }, + { + name: "fail/error-parsing-domain", + domain: string(byte(0)), + constraint: ".local", + want: false, + wantErr: true, + }, + { + name: "fail/error-parsing-constraint", + domain: "example.local", + constraint: string(byte(0)), + want: false, + wantErr: true, + }, + { + name: "fail/no-subdomain", + domain: "local", + constraint: ".local", + want: false, + wantErr: false, + }, + { + name: "fail/too-many-subdomains", + domain: "www.example.local", + constraint: ".local", + want: false, + wantErr: false, + }, + { + name: "fail/wrong-domain", + domain: "example.notlocal", + constraint: ".local", + want: false, + wantErr: false, + }, + { + name: "false/idna-internationalized-domain-name", + domain: "JP納豆.例.jp", // Example value from https://www.w3.org/International/articles/idn-and-iri/ + constraint: ".例.jp", + want: false, + wantErr: true, + }, + { + name: "false/idna-internationalized-domain-name-constraint", + domain: "xn--jp-cd2fp15c.xn--fsq.jp", // Example value from https://www.w3.org/International/articles/idn-and-iri/ + constraint: ".例.jp", + want: false, + wantErr: true, + }, + { + name: "ok/empty-constraint", + domain: "www.example.com", + constraint: "", + want: true, + wantErr: false, + }, + { + name: "ok/wildcard", + domain: "www.example.com", + constraint: ".example.com", // internally we're using the x509 period prefix as the indicator for exactly one subdomain + want: true, + wantErr: false, + }, + { + name: "ok/wildcard-literal", + allowLiteralWildcardNames: true, + domain: "*.example.com", // specifically allowed using an option on the NamePolicyEngine + constraint: ".example.com", // internally we're using the x509 period prefix as the indicator for exactly one subdomain + want: true, + wantErr: false, + }, + { + name: "ok/specific-domain", + domain: "www.example.com", + constraint: "www.example.com", + want: true, + wantErr: false, + }, + { + name: "ok/different-case", + domain: "WWW.EXAMPLE.com", + constraint: "www.example.com", + want: true, + wantErr: false, + }, + { + name: "ok/idna-internationalized-domain-name-punycode", + domain: "xn--jp-cd2fp15c.xn--fsq.jp", // Example value from https://www.w3.org/International/articles/idn-and-iri/ + constraint: ".xn--fsq.jp", + want: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + engine := NamePolicyEngine{ + allowLiteralWildcardNames: tt.allowLiteralWildcardNames, + } + got, err := engine.matchDomainConstraint(tt.domain, tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("NamePolicyEngine.matchDomainConstraint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("NamePolicyEngine.matchDomainConstraint() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_matchIPConstraint(t *testing.T) { + nat64IP, nat64Net, err := net.ParseCIDR("64:ff9b::/96") + assert.NoError(t, err) + tests := []struct { + name string + ip net.IP + constraint *net.IPNet + want bool + wantErr bool + }{ + { + name: "false/ipv4-in-ipv6-nat64", + ip: net.ParseIP("192.0.2.128"), + constraint: nat64Net, + want: false, + wantErr: false, + }, + { + name: "ok/ipv4", + ip: net.ParseIP("127.0.0.1"), + constraint: &net.IPNet{ + IP: net.ParseIP("127.0.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + want: true, + wantErr: false, + }, + { + name: "ok/ipv6", + ip: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7335"), + constraint: &net.IPNet{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + want: true, + wantErr: false, + }, + { + name: "ok/ipv4-in-ipv6", // ipv4 in ipv6 addresses are considered the same in the current implementation, because Go parses them as IPv4 + ip: net.ParseIP("::ffff:192.0.2.128"), + constraint: &net.IPNet{ + IP: net.ParseIP("192.0.2.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + want: true, + wantErr: false, + }, + { + name: "ok/ipv4-in-ipv6-nat64-fixed-ip", + ip: nat64IP, + constraint: nat64Net, + want: true, + wantErr: false, + }, + { + name: "ok/ipv4-in-ipv6-nat64", + ip: net.ParseIP("64:ff9b::192.0.2.129"), + constraint: nat64Net, + want: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := matchIPConstraint(tt.ip, tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("matchIPConstraint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("matchIPConstraint() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNamePolicyEngine_matchEmailConstraint(t *testing.T) { + + tests := []struct { + name string + engine *NamePolicyEngine + mailbox rfc2821Mailbox + constraint string + want bool + wantErr bool + }{ + { + name: "fail/asterisk-label", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "@host.*.example.com", + want: false, + wantErr: true, + }, + { + name: "fail/asterisk-inside-domain", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "@h*st.example.com", + want: false, + wantErr: true, + }, + { + name: "fail/parse-email", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "@example.com", + want: false, + wantErr: true, + }, + { + name: "fail/wildcard", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "example.com", + want: false, + wantErr: false, + }, + { + name: "fail/wildcard-x509-period", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: ".local", + want: false, + wantErr: false, + }, + { + name: "fail/specific-mail-wrong-domain", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "mail@example.com", + want: false, + wantErr: false, + }, + { + name: "fail/specific-mail-wrong-local", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "root", + domain: "example.com", + }, + constraint: "mail@example.com", + want: false, + wantErr: false, + }, + { + name: "ok/wildcard", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "local", // "wildcard" for the local domain + want: true, + wantErr: false, + }, + { + name: "ok/wildcard-x509-period", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "example.local", + }, + constraint: ".local", // "wildcard" for the local domain; requires exactly 1 subdomain + want: true, + wantErr: false, + }, + { + name: "ok/asterisk-prefix", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "*@example.com", + want: false, + wantErr: false, + }, + { + name: "ok/asterisk-prefix-match", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "*", + domain: "example.com", + }, + constraint: "*@example.com", + want: true, + wantErr: false, + }, + { + name: "ok/asterisk-inside-local", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "m*il@local", + want: false, + wantErr: false, + }, + { + name: "ok/asterisk-inside-local-match", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "m*il", + domain: "local", + }, + constraint: "m*il@local", + want: true, + wantErr: false, + }, + { + name: "ok/specific-mail", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "local", + }, + constraint: "mail@local", + want: true, + wantErr: false, + }, + { + name: "ok/wildcard-tld", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "example.com", + }, + constraint: "example.com", // "wildcard" for 'example.com' + want: true, + wantErr: false, + }, + { + name: "ok/different-case", + engine: &NamePolicyEngine{}, + mailbox: rfc2821Mailbox{ + local: "mail", + domain: "EXAMPLE.com", + }, + constraint: "example.com", // "wildcard" for 'example.com' + want: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.engine.matchEmailConstraint(tt.mailbox, tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("NamePolicyEngine.matchEmailConstraint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("NamePolicyEngine.matchEmailConstraint() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNamePolicyEngine_matchURIConstraint(t *testing.T) { + tests := []struct { + name string + engine *NamePolicyEngine + uri *url.URL + constraint string + want bool + wantErr bool + }{ + { + name: "fail/empty-host", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "", + }, + constraint: ".local", + want: false, + wantErr: true, + }, + { + name: "fail/host-with-asterisk-prefix", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "*.local", + }, + constraint: ".local", + want: false, + wantErr: true, + }, + { + name: "fail/host-with-asterisk-label", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "host.*.local", + }, + constraint: ".local", + want: false, + wantErr: true, + }, + { + name: "fail/host-with-asterisk-inside", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "h*st.local", + }, + constraint: ".local", + want: false, + wantErr: true, + }, + { + name: "fail/wildcard", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "www.example.notlocal", + }, + constraint: ".example.local", // using x509 period as the "wildcard"; expects a single subdomain + want: false, + wantErr: false, + }, + { + name: "fail/wildcard-subdomains-too-deep", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "www.sub.example.local", + }, + constraint: ".example.local", // using x509 period as the "wildcard"; expects a single subdomain + want: false, + wantErr: false, + }, + { + name: "fail/host-with-port-split-error", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "www.example.local::8080", + }, + constraint: ".example.local", // using x509 period as the "wildcard"; expects a single subdomain + want: false, + wantErr: true, + }, + { + name: "fail/host-with-ipv4", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "127.0.0.1", + }, + constraint: ".example.local", // using x509 period as the "wildcard"; expects a single subdomain + want: false, + wantErr: true, + }, + { + name: "fail/host-with-ipv6", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", + }, + constraint: ".example.local", // using x509 period as the "wildcard"; expects a single subdomain + want: false, + wantErr: true, + }, + { + name: "ok/wildcard", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "www.example.local", + }, + constraint: ".example.local", // using x509 period as the "wildcard"; expects a single subdomain + want: true, + wantErr: false, + }, + { + name: "ok/host-with-port", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "www.example.local:8080", + }, + constraint: ".example.local", // using x509 period as the "wildcard"; expects a single subdomain + want: true, + wantErr: false, + }, + { + name: "ok/different-case", + engine: &NamePolicyEngine{}, + uri: &url.URL{ + Scheme: "https", + Host: "www.EXAMPLE.local", + }, + constraint: ".example.local", // using x509 period as the "wildcard"; expects a single subdomain + want: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.engine.matchURIConstraint(tt.uri, tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("NamePolicyEngine.matchURIConstraint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("NamePolicyEngine.matchURIConstraint() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNamePolicyEngine_X509_AllAllowed(t *testing.T) { + tests := []struct { + name string + options []NamePolicyOption + cert *x509.Certificate + want bool + wantErr *NamePolicyError + }{ + // SINGLE SAN TYPE PERMITTED FAILURE TESTS + { + name: "fail/dns-permitted", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"www.example.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "www.example.com", + }, + }, + { + name: "fail/dns-permitted-wildcard-literal-x509", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.x509local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + "*.x509local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "*.x509local", + }, + }, + { + name: "fail/dns-permitted-single-host", + options: []NamePolicyOption{ + WithPermittedDNSDomains("host.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"differenthost.local"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "differenthost.local", + }, + }, + { + name: "fail/dns-permitted-no-label", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"local"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "local", + }, + }, + { + name: "fail/dns-permitted-empty-label", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"www..local"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotParseDomain, + NameType: DNSNameType, + Name: "www..local", + }, + }, + { + name: "fail/dns-permitted-dot-domain", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + ".local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: ".local", + }, + }, + { + name: "fail/dns-permitted-wildcard-multiple-subdomains", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + "sub.example.local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "sub.example.local", + }, + }, + { + name: "fail/dns-permitted-wildcard-literal", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + "*.local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "*.local", + }, + }, + { + name: "fail/dns-permitted-idna-internationalized-domain", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.豆.jp"), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + string(byte(0)) + ".例.jp", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotParseDomain, + NameType: DNSNameType, + Name: string(byte(0)) + ".例.jp", + }, + }, + { + name: "fail/ipv4-permitted", + options: []NamePolicyOption{ + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("1.1.1.1")}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "1.1.1.1", + }, + }, + { + name: "fail/ipv6-permitted", + options: []NamePolicyOption{ + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + ), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("3001:0db8:85a3:0000:0000:8a2e:0370:7334")}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "3001:db8:85a3::8a2e:370:7334", // IPv6 is shortened internally + }, + }, + { + name: "fail/mail-permitted-wildcard", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{ + "test@local.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "test@local.com", + }, + }, + { + name: "fail/mail-permitted-wildcard-x509", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("example.com"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{ + "test@local.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "test@local.com", + }, + }, + { + name: "fail/mail-permitted-specific-mailbox", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("test@local.com"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{ + "root@local.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "root@local.com", + }, + }, + { + name: "fail/mail-permitted-wildcard-subdomain", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{ + "test@sub.example.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "test@sub.example.com", + }, + }, + { + name: "fail/mail-permitted-idna-internationalized-domain", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@例.jp"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"bücher@例.jp"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotParseRFC822Name, + NameType: EmailNameType, + Name: "bücher@例.jp", + }, + }, + { + name: "fail/mail-permitted-idna-internationalized-domain-rfc822", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@例.jp"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"bücher@例.jp" + string(byte(0))}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotParseRFC822Name, + NameType: EmailNameType, + Name: "bücher@例.jp" + string(byte(0)), + }, + }, + { + name: "fail/mail-permitted-idna-internationalized-domain-ascii", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@例.jp"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@xn---bla.jp"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotParseDomain, + NameType: EmailNameType, + Name: "mail@xn---bla.jp", + }, + }, + { + name: "fail/uri-permitted-domain-wildcard", + options: []NamePolicyOption{ + WithPermittedURIDomains("*.local"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "example.com", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: URINameType, + Name: "https://example.com", + }, + }, + { + name: "fail/uri-permitted", + options: []NamePolicyOption{ + WithPermittedURIDomains("test.local"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "bad.local", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: URINameType, + Name: "https://bad.local", + }, + }, + { + name: "fail/uri-permitted-with-literal-wildcard", // don't allow literal wildcard in URI, e.g. xxxx://*.domain.tld + options: []NamePolicyOption{ + WithPermittedURIDomains("*.local"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "*.local", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotMatchNameToConstraint, + NameType: URINameType, + Name: "https://*.local", + }, + }, + { + name: "fail/uri-permitted-idna-internationalized-domain", + options: []NamePolicyOption{ + WithPermittedURIDomains("*.bücher.example.com"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "abc.bücher.example.com", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotMatchNameToConstraint, + NameType: URINameType, + Name: "https://abc.b%C3%BCcher.example.com", + }, + }, + // SINGLE SAN TYPE EXCLUDED FAILURE TESTS + { + name: "fail/dns-excluded", + options: []NamePolicyOption{ + WithExcludedDNSDomains("*.example.com"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"www.example.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "www.example.com", + }, + }, + { + name: "fail/dns-excluded-single-host", + options: []NamePolicyOption{ + WithExcludedDNSDomains("host.example.com"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"host.example.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "host.example.com", + }, + }, + { + name: "fail/ipv4-excluded", + options: []NamePolicyOption{ + WithExcludedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "127.0.0.1", + }, + }, + { + name: "fail/ipv6-excluded", + options: []NamePolicyOption{ + WithExcludedIPRanges( + &net.IPNet{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + ), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "2001:db8:85a3::8a2e:370:7334", + }, + }, + { + name: "fail/mail-excluded", + options: []NamePolicyOption{ + WithExcludedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@example.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "mail@example.com", + }, + }, + { + name: "fail/uri-excluded", + options: []NamePolicyOption{ + WithExcludedURIDomains("*.example.com"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: URINameType, + Name: "https://www.example.com", + }, + }, + { + name: "fail/uri-excluded-with-literal-wildcard", // don't allow literal wildcard in URI, e.g. xxxx://*.domain.tld + options: []NamePolicyOption{ + WithExcludedURIDomains("*.local"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "*.local", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotMatchNameToConstraint, + NameType: URINameType, + Name: "https://*.local", + }, + }, + // SUBJECT FAILURE TESTS + { + name: "fail/subject-permitted-no-match", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedCommonNames("this name is allowed", "and this one too"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "some certificate name", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, // only permitted names allowed + NameType: CNNameType, + Name: "some certificate name", + }, + }, + { + name: "fail/subject-excluded-match", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedCommonNames("this name is not allowed"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "this name is not allowed", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotParseDomain, // CN cannot be parsed as DNS in this case + NameType: CNNameType, + Name: "this name is not allowed", + }, + }, + { + name: "fail/subject-dns-no-domain", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "name with space.local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: CannotParseDomain, + NameType: CNNameType, + Name: "name with space.local", + }, + }, + { + name: "fail/subject-dns-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "example.notlocal", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "example.notlocal", + }, + }, + { + name: "fail/subject-dns-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "example.local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "example.local", + }, + }, + { + name: "fail/subject-ipv4-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "10.10.10.10", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "10.10.10.10", + }, + }, + { + name: "fail/subject-ipv4-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.30", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "127.0.0.30", + }, + }, + { + name: "fail/subject-ipv6-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + ), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "2002:0db8:85a3:0000:0000:8a2e:0370:7339", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "2002:db8:85a3::8a2e:370:7339", + }, + }, + { + name: "fail/subject-ipv6-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedIPRanges( + &net.IPNet{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + ), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "2001:0db8:85a3:0000:0000:8a2e:0370:7339", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "2001:db8:85a3::8a2e:370:7339", + }, + }, + { + name: "fail/subject-email-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedEmailAddresses("@example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "mail@smallstep.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "mail@smallstep.com", + }, + }, + { + name: "fail/subject-email-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedEmailAddresses("@example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "mail@example.local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "mail@example.local", + }, + }, + { + name: "fail/subject-uri-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedURIDomains("*.example.com"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "https://www.google.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "https://www.google.com", + }, + }, + { + name: "fail/subject-uri-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedURIDomains("*.example.com"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "https://www.example.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "https://www.example.com", + }, + }, + // DIFFERENT SAN PERMITTED FAILURE TESTS + { + name: "fail/dns-permitted-with-ip-name", // when only DNS is permitted, IPs are not allowed. + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "127.0.0.1", + }, + }, + { + name: "fail/dns-permitted-with-mail", // when only DNS is permitted, mails are not allowed. + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@smallstep.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "mail@smallstep.com", + }, + }, + { + name: "fail/dns-permitted-with-uri", // when only DNS is permitted, URIs are not allowed. + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: URINameType, + Name: "https://www.example.com", + }, + }, + { + name: "fail/ip-permitted-with-dns-name", // when only IP is permitted, DNS names are not allowed. + options: []NamePolicyOption{ + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + DNSNames: []string{"www.example.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "www.example.com", + }, + }, + { + name: "fail/ip-permitted-with-mail", // when only IP is permitted, mails are not allowed. + options: []NamePolicyOption{ + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@smallstep.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "mail@smallstep.com", + }, + }, + { + name: "fail/ip-permitted-with-uri", // when only IP is permitted, URIs are not allowed. + options: []NamePolicyOption{ + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: URINameType, + Name: "https://www.example.com", + }, + }, + { + name: "fail/mail-permitted-with-dns-name", // when only mail is permitted, DNS names are not allowed. + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"www.example.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "www.example.com", + }, + }, + { + name: "fail/mail-permitted-with-ip", // when only mail is permitted, IPs are not allowed. + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{ + net.ParseIP("127.0.0.1"), + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "127.0.0.1", + }, + }, + { + name: "fail/mail-permitted-with-uri", // when only mail is permitted, URIs are not allowed. + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: URINameType, + Name: "https://www.example.com", + }, + }, + { + name: "fail/uri-permitted-with-dns-name", // when only URI is permitted, DNS names are not allowed. + options: []NamePolicyOption{ + WithPermittedURIDomains("*.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"host.local"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "host.local", + }, + }, + { + name: "fail/uri-permitted-with-ip-name", // when only URI is permitted, IPs are not allowed. + options: []NamePolicyOption{ + WithPermittedURIDomains("*.local"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{ + net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "2001:db8:85a3::8a2e:370:7334", + }, + }, + { + name: "fail/uri-permitted-with-ip-name", // when only URI is permitted, mails are not allowed. + options: []NamePolicyOption{ + WithPermittedURIDomains("*.local"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@smallstep.com"}, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "mail@smallstep.com", + }, + }, + // COMBINED FAILURE TESTS + { + name: "fail/combined-simple-all-badhost.local-common-name", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedDNSDomains("*.local"), + WithPermittedCIDRs("127.0.0.1/24"), + WithPermittedEmailAddresses("@example.local"), + WithPermittedURIDomains("*.example.local"), + WithExcludedDNSDomains("badhost.local"), + WithExcludedCIDRs("127.0.0.128/25"), + WithExcludedEmailAddresses("badmail@example.local"), + WithExcludedURIDomains("badwww.example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "badhost.local", + }, + DNSNames: []string{"example.local"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.40")}, + EmailAddresses: []string{"mail@example.local"}, + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: CNNameType, + Name: "badhost.local", + }, + }, + { + name: "fail/combined-simple-all-anotherbadhost.local-dns", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + WithPermittedCIDRs("127.0.0.1/24"), + WithPermittedEmailAddresses("@example.local"), + WithPermittedURIDomains("*.example.local"), + WithExcludedDNSDomains("anotherbadhost.local"), + WithExcludedCIDRs("127.0.0.128/25"), + WithExcludedEmailAddresses("badmail@example.local"), + WithExcludedURIDomains("badwww.example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "badhost.local", + }, + DNSNames: []string{"anotherbadhost.local"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.40")}, + EmailAddresses: []string{"mail@example.local"}, + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "anotherbadhost.local", + }, + }, + { + name: "fail/combined-simple-all-badmail@example.local", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + WithPermittedCIDRs("127.0.0.1/24"), + WithPermittedEmailAddresses("@example.local"), + WithPermittedURIDomains("*.example.local"), + WithExcludedDNSDomains("badhost.local"), + WithExcludedCIDRs("127.0.0.128/25"), + WithExcludedEmailAddresses("badmail@example.local"), + WithExcludedURIDomains("badwww.example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "badhost.local", + }, + DNSNames: []string{"example.local"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.40")}, + EmailAddresses: []string{"mail@example.local", "badmail@example.local"}, + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "badmail@example.local", + }, + }, + // NO CONSTRAINT SUCCESS TESTS + { + name: "ok/dns-no-constraints", + options: []NamePolicyOption{}, + cert: &x509.Certificate{ + DNSNames: []string{"www.example.com"}, + }, + want: true, + }, + { + name: "ok/ipv4-no-constraints", + options: []NamePolicyOption{}, + cert: &x509.Certificate{ + IPAddresses: []net.IP{ + net.ParseIP("127.0.0.1"), + }, + }, + want: true, + }, + { + name: "ok/ipv6-no-constraints", + options: []NamePolicyOption{}, + cert: &x509.Certificate{ + IPAddresses: []net.IP{ + net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + }, + }, + want: true, + }, + { + name: "ok/mail-no-constraints", + options: []NamePolicyOption{}, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@smallstep.com"}, + }, + want: true, + }, + { + name: "ok/uri-no-constraints", + options: []NamePolicyOption{}, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: true, + }, + { + name: "ok/subject-no-constraints", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "www.example.com", + }, + }, + want: true, + }, + { + name: "ok/subject-empty-no-constraints", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "", + }, + }, + want: true, + }, + { + name: "ok/subject-permitted-match", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedCommonNames("this name is allowed"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "this name is allowed", + }, + }, + want: true, + }, + { + name: "ok/subject-excluded-match", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedCommonNames("this name is not allowed"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "some other name", + }, + }, + want: true, + }, + // SINGLE SAN TYPE PERMITTED SUCCESS TESTS + { + name: "ok/dns-permitted", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"example.local"}, + }, + want: true, + }, + { + name: "ok/dns-permitted-wildcard", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local", "*.x509local"), + WithAllowLiteralWildcardNames(), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + "host.local", + "test.x509local", + }, + }, + want: true, + }, + { + name: "ok/dns-permitted-wildcard-literal", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local", "*.x509local"), + WithAllowLiteralWildcardNames(), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + "*.local", + "*.x509local", + }, + }, + want: true, + }, + { + name: "ok/dns-permitted-combined", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local", "*.x509local", "host.example.com"), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + "example.local", + "example.x509local", + "host.example.com", + }, + }, + want: true, + }, + { + name: "ok/dns-permitted-idna-internationalized-domain", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.例.jp"), + }, + cert: &x509.Certificate{ + DNSNames: []string{ + "JP納豆.例.jp", // Example value from https://www.w3.org/International/articles/idn-and-iri/ + }, + }, + want: true, + }, + { + name: "ok/ipv4-permitted", + options: []NamePolicyOption{ + WithPermittedCIDRs("127.0.0.1/24"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.20")}, + }, + want: true, + }, + { + name: "ok/ipv6-permitted", + options: []NamePolicyOption{ + WithPermittedCIDRs("2001:0db8:85a3:0000:0000:8a2e:0370:7334/120"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7339")}, + }, + want: true, + }, + { + name: "ok/mail-permitted-wildcard", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{ + "test@example.com", + }, + }, + want: true, + }, + { + name: "ok/mail-permitted-plain-domain", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("example.com"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{ + "test@example.com", + }, + }, + want: true, + }, + { + name: "ok/mail-permitted-specific-mailbox", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("test@local.com"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{ + "test@local.com", + }, + }, + want: true, + }, + { + name: "ok/mail-permitted-idna-internationalized-domain", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@例.jp"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{}, + }, + want: true, + }, + { + name: "ok/uri-permitted-domain-wildcard", + options: []NamePolicyOption{ + WithPermittedURIDomains("*.local"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "example.local", + }, + }, + }, + want: true, + }, + { + name: "ok/uri-permitted-specific-uri", + options: []NamePolicyOption{ + WithPermittedURIDomains("test.local"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "test.local", + }, + }, + }, + want: true, + }, + { + name: "ok/uri-permitted-with-port", + options: []NamePolicyOption{ + WithPermittedURIDomains("*.example.com"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com:8080", + }, + }, + }, + want: true, + }, + { + name: "ok/uri-permitted-idna-internationalized-domain", + options: []NamePolicyOption{ + WithPermittedURIDomains("*.bücher.example.com"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "abc.xn--bcher-kva.example.com", + }, + }, + }, + want: true, + }, + { + name: "ok/uri-permitted-idna-internationalized-domain", + options: []NamePolicyOption{ + WithPermittedURIDomains("bücher.example.com"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "xn--bcher-kva.example.com", + }, + }, + }, + want: true, + }, + // SINGLE SAN TYPE EXCLUDED SUCCESS TESTS + { + name: "ok/dns-excluded", + options: []NamePolicyOption{ + WithExcludedDNSDomains("*.notlocal"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"example.local"}, + }, + want: true, + }, + { + name: "ok/ipv4-excluded", + options: []NamePolicyOption{ + WithExcludedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("10.10.10.10")}, + }, + want: true, + }, + { + name: "ok/ipv6-excluded", + options: []NamePolicyOption{ + WithExcludedCIDRs("2001:0db8:85a3:0000:0000:8a2e:0370:7334/120"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("2003:0db8:85a3:0000:0000:8a2e:0370:7334")}, + }, + want: true, + }, + { + name: "ok/mail-excluded", + options: []NamePolicyOption{ + WithExcludedEmailAddresses("@notlocal"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@local"}, + }, + want: true, + }, + { + name: "ok/mail-excluded-with-subdomain", + options: []NamePolicyOption{ + WithExcludedEmailAddresses("@local"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@example.local"}, + }, + want: true, + }, + { + name: "ok/uri-excluded", + options: []NamePolicyOption{ + WithExcludedURIDomains("*.google.com"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: true, + }, + // SUBJECT SUCCESS TESTS + { + name: "ok/subject-empty", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "", + }, + DNSNames: []string{"example.local"}, + }, + want: true, + }, + { + name: "ok/subject-dns-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "example.local", + }, + }, + want: true, + }, + { + name: "ok/subject-dns-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedDNSDomains("*.notlocal"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "example.local", + }, + }, + want: true, + }, + { + name: "ok/subject-ipv4-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.20", + }, + }, + want: true, + }, + { + name: "ok/subject-ipv4-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedIPRanges( + &net.IPNet{ + IP: net.ParseIP("128.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + ), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + }, + want: true, + }, + { + name: "ok/subject-ipv6-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedIPRanges( + &net.IPNet{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + ), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "2001:0db8:85a3:0000:0000:8a2e:0370:7339", + }, + }, + want: true, + }, + { + name: "ok/subject-ipv6-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedIPRanges( + &net.IPNet{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Mask: net.CIDRMask(120, 128), + }, + ), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "2009:0db8:85a3:0000:0000:8a2e:0370:7339", + }, + }, + want: true, + }, + { + name: "ok/subject-email-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedEmailAddresses("@example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "mail@example.local", + }, + }, + want: true, + }, + { + name: "ok/subject-email-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedEmailAddresses("@example.notlocal"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "mail@example.local", + }, + }, + want: true, + }, + { + name: "ok/subject-uri-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedURIDomains("*.example.com"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "https://www.example.com", + }, + }, + want: true, + }, + { + name: "ok/subject-uri-excluded", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedURIDomains("*.smallstep.com"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "https://www.example.com", + }, + }, + want: true, + }, + // DIFFERENT SAN TYPE EXCLUDED SUCCESS TESTS + { + name: "ok/dns-excluded-with-ip-name", // when only DNS is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + }, + want: true, + }, + { + name: "ok/dns-excluded-with-mail", // when only DNS is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@example.com"}, + }, + want: true, + }, + { + name: "ok/dns-excluded-with-mail", // when only DNS is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: true, + }, + { + name: "ok/ip-excluded-with-dns", // when only IP is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedCIDRs("127.0.0.1/24"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"test.local"}, + }, + want: true, + }, + { + name: "ok/ip-excluded-with-mail", // when only IP is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedCIDRs("127.0.0.1/24"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@example.com"}, + }, + want: true, + }, + { + name: "ok/ip-excluded-with-mail", // when only IP is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedCIDRs("127.0.0.1/24"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: true, + }, + { + name: "ok/mail-excluded-with-dns", // when only mail is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"test.local"}, + }, + want: true, + }, + { + name: "ok/mail-excluded-with-ip", // when only mail is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + }, + want: true, + }, + { + name: "ok/mail-excluded-with-uri", // when only mail is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedEmailAddresses("@example.com"), + }, + cert: &x509.Certificate{ + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.com", + }, + }, + }, + want: true, + }, + { + name: "ok/uri-excluded-with-dns", // when only URI is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedURIDomains("*.example.local"), + }, + cert: &x509.Certificate{ + DNSNames: []string{"test.example.local"}, + }, + want: true, + }, + { + name: "ok/uri-excluded-with-dns", // when only URI is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedURIDomains("*.example.local"), + }, + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + }, + want: true, + }, + { + name: "ok/uri-excluded-with-mail", // when only URI is exluded, we allow anything else + options: []NamePolicyOption{ + WithExcludedURIDomains("*.example.local"), + }, + cert: &x509.Certificate{ + EmailAddresses: []string{"mail@example.local"}, + }, + want: true, + }, + { + name: "ok/dns-excluded-with-subject-ip-name", // when only DNS is exluded, we allow anything else + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithExcludedDNSDomains("*.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + }, + want: true, + }, + // COMBINED SUCCESS TESTS + { + name: "ok/combined-simple-permitted", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedDNSDomains("*.local"), + WithPermittedCIDRs("127.0.0.1/24"), + WithPermittedEmailAddresses("@example.local"), + WithPermittedURIDomains("*.example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "somehost.local", + }, + DNSNames: []string{"example.local"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.15")}, + EmailAddresses: []string{"mail@example.local"}, + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: true, + }, + { + name: "ok/combined-simple-permitted-without-subject-verification", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + WithPermittedCIDRs("127.0.0.1/24"), + WithPermittedEmailAddresses("@example.local"), + WithPermittedURIDomains("*.example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "forbidden-but-non-verified-domain.example.com", + }, + DNSNames: []string{"example.local"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + EmailAddresses: []string{"mail@example.local"}, + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: true, + }, + { + name: "ok/combined-simple-all", + options: []NamePolicyOption{ + WithSubjectCommonNameVerification(), + WithPermittedDNSDomains("*.local"), + WithPermittedCIDRs("127.0.0.1/24"), + WithPermittedEmailAddresses("@example.local"), + WithPermittedURIDomains("*.example.local"), + WithExcludedDNSDomains("badhost.local"), + WithExcludedCIDRs("127.0.0.128/25"), + WithExcludedEmailAddresses("badmail@example.local"), + WithExcludedURIDomains("badwww.example.local"), + }, + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "somehost.local", + }, + DNSNames: []string{"example.local"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + EmailAddresses: []string{"mail@example.local"}, + URIs: []*url.URL{ + { + Scheme: "https", + Host: "www.example.local", + }, + }, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + engine, err := New(tt.options...) + assert.NoError(t, err) + assert.NotNil(t, engine) + gotErr := engine.IsX509CertificateAllowed(tt.cert) + wantErr := tt.wantErr != nil + + if (gotErr != nil) != wantErr { + t.Errorf("NamePolicyEngine.IsX509CertificateAllowed() error = %v, wantErr %v", gotErr, tt.wantErr) + return + } + if gotErr != nil { + var npe *NamePolicyError + assert.True(t, errors.As(gotErr, &npe)) + assert.NotEqual(t, "", npe.Error()) + assert.Equal(t, tt.wantErr.Reason, npe.Reason) + assert.Equal(t, tt.wantErr.NameType, npe.NameType) + assert.Equal(t, tt.wantErr.Name, npe.Name) + assert.NotEqual(t, "", npe.Detail()) + //assert.Equals(t, tt.err.Reason, npe.Reason) // NOTE: reason detail is skipped; it's a detail + } + + // Perform the same tests for a CSR, which are similar to Certificates + csr := &x509.CertificateRequest{ + Subject: tt.cert.Subject, + DNSNames: tt.cert.DNSNames, + EmailAddresses: tt.cert.EmailAddresses, + IPAddresses: tt.cert.IPAddresses, + URIs: tt.cert.URIs, + } + gotErr = engine.IsX509CertificateRequestAllowed(csr) + wantErr = tt.wantErr != nil + if (gotErr != nil) != wantErr { + t.Errorf("NamePolicyEngine.AreCSRNamesAllowed() error = %v, wantErr %v", gotErr, tt.wantErr) + return + } + if gotErr != nil { + var npe *NamePolicyError + assert.True(t, errors.As(gotErr, &npe)) + assert.NotEqual(t, "", npe.Error()) + assert.Equal(t, tt.wantErr.Reason, npe.Reason) + assert.Equal(t, tt.wantErr.NameType, npe.NameType) + assert.Equal(t, tt.wantErr.Name, npe.Name) + assert.NotEqual(t, "", npe.Detail()) + //assert.Equals(t, tt.err.Reason, npe.Reason) // NOTE: reason detail is skipped; it's a detail + } + }) + } +} + +func TestNamePolicyEngine_SSH_ArePrincipalsAllowed(t *testing.T) { + tests := []struct { + name string + options []NamePolicyOption + cert *ssh.Certificate + want bool + wantErr *NamePolicyError + }{ + { + name: "fail/host-with-permitted-dns-domain", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "host.example.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "host.example.com", + }, + }, + { + name: "fail/host-with-excluded-dns-domain", + options: []NamePolicyOption{ + WithExcludedDNSDomains("*.local"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "host.local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "host.local", + }, + }, + { + name: "fail/host-with-permitted-cidr", + options: []NamePolicyOption{ + WithPermittedCIDRs("127.0.0.1/24"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "192.168.0.22", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "192.168.0.22", + }, + }, + { + name: "fail/host-with-excluded-cidr", + options: []NamePolicyOption{ + WithExcludedCIDRs("127.0.0.1/24"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "127.0.0.0", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: "127.0.0.0", + }, + }, + { + name: "fail/user-with-permitted-email", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "mail@local", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "mail@local", + }, + }, + { + name: "fail/user-with-excluded-email", + options: []NamePolicyOption{ + WithExcludedEmailAddresses("@example.com"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "mail@example.com", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "mail@example.com", + }, + }, + { + name: "fail/host-with-permitted-principals", + options: []NamePolicyOption{ + WithPermittedPrincipals("localhost"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "host", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "host", + }, + }, + { + name: "fail/user-with-permitted-principals", + options: []NamePolicyOption{ + WithPermittedPrincipals("user"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "root", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: PrincipalNameType, + Name: "root", + }, + }, + { + name: "fail/user-with-excluded-principals", + options: []NamePolicyOption{ + WithExcludedPrincipals("user"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "user", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: PrincipalNameType, + Name: "user", + }, + }, + { + name: "fail/user-with-permitted-principal-as-mail", + options: []NamePolicyOption{ + WithPermittedPrincipals("ops"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "ops@work", // this is (currently) parsed as an email-like principal; not allowed with just "ops" as the permitted principal + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "ops@work", + }, + }, + { + name: "fail/host-principal-with-permitted-dns-domain", // when only DNS is permitted, username principals are not allowed. + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "user", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "user", + }, + }, + { + name: "fail/host-principal-with-permitted-ip-range", // when only IPs are permitted, username principals are not allowed. + options: []NamePolicyOption{ + WithPermittedCIDRs("127.0.0.1/24"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "user", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "user", + }, + }, + { + name: "fail/user-principal-with-permitted-email", // when only emails are permitted, username principals are not allowed. + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "user", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: PrincipalNameType, + Name: "user", + }, + }, + { + name: "fail/combined-user", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@smallstep.com"), + WithExcludedEmailAddresses("root@smallstep.com"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "someone@smallstep.com", + "someone", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: PrincipalNameType, + Name: "someone", + }, + }, + { + name: "fail/combined-user-with-excluded-user-principal", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@smallstep.com"), + WithExcludedPrincipals("root"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "someone@smallstep.com", + "root", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: PrincipalNameType, + Name: "root", + }, + }, + { + name: "fail/host-with-permitted-user-principals", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@work"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "example.work", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "example.work", + }, + }, + { + name: "fail/user-with-permitted-user-principals", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "herman@work", + }, + }, + want: false, + wantErr: &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: "herman@work", + }, + }, + { + name: "ok/host-with-permitted-dns-domain", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "host.local", + }, + }, + want: true, + }, + { + name: "ok/host-with-excluded-dns-domain", + options: []NamePolicyOption{ + WithExcludedDNSDomains("*.example.com"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "host.local", + }, + }, + want: true, + }, + { + name: "ok/host-with-permitted-ip", + options: []NamePolicyOption{ + WithPermittedCIDRs("127.0.0.1/24"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "127.0.0.33", + }, + }, + want: true, + }, + { + name: "ok/host-with-excluded-ip", + options: []NamePolicyOption{ + WithExcludedCIDRs("127.0.0.1/24"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "192.168.0.35", + }, + }, + want: true, + }, + { + name: "ok/host-with-excluded-principals", + options: []NamePolicyOption{ + WithExcludedPrincipals("localhost"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "localhost", + }, + }, + want: true, + }, + { + name: "ok/user-with-permitted-email", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@example.com"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "mail@example.com", + }, + }, + want: true, + }, + { + name: "ok/user-with-excluded-email", + options: []NamePolicyOption{ + WithExcludedEmailAddresses("@example.com"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "mail@local", + }, + }, + want: true, + }, + { + name: "ok/user-with-permitted-principals", + options: []NamePolicyOption{ + WithPermittedPrincipals("*"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "user", + }, + }, + want: true, + }, + { + name: "ok/user-with-excluded-principals", + options: []NamePolicyOption{ + WithExcludedPrincipals("user"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "root", + }, + }, + want: true, + }, + { + name: "ok/combined-user", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@smallstep.com"), + WithPermittedPrincipals("*"), // without specifying the wildcard, "someone" would not be allowed. + WithExcludedEmailAddresses("root@smallstep.com"), + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "someone@smallstep.com", + "someone", + }, + }, + want: true, + }, + { + name: "ok/combined-user-with-excluded-user-principal", + options: []NamePolicyOption{ + WithPermittedEmailAddresses("@smallstep.com"), + WithExcludedEmailAddresses("root@smallstep.com"), + WithExcludedPrincipals("root"), // unlike the previous test, this implicitly allows any other username principal + }, + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{ + "someone@smallstep.com", + "someone", + }, + }, + want: true, + }, + { + name: "ok/combined-host", + options: []NamePolicyOption{ + WithPermittedDNSDomains("*.local"), + WithPermittedCIDRs("127.0.0.1/24"), + WithExcludedDNSDomains("badhost.local"), + WithExcludedCIDRs("127.0.0.128/25"), + }, + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{ + "example.local", + "127.0.0.31", + }, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + engine, err := New(tt.options...) + assert.NoError(t, err) + gotErr := engine.IsSSHCertificateAllowed(tt.cert) + wantErr := tt.wantErr != nil + if (gotErr != nil) != wantErr { + t.Errorf("NamePolicyEngine.IsSSHCertificateAllowed() error = %v, wantErr %v", gotErr, tt.wantErr) + return + } + if gotErr != nil { + var npe *NamePolicyError + assert.True(t, errors.As(gotErr, &npe)) + assert.NotEqual(t, "", npe.Error()) + assert.Equal(t, tt.wantErr.Reason, npe.Reason) + assert.Equal(t, tt.wantErr.NameType, npe.NameType) + assert.Equal(t, tt.wantErr.Name, npe.Name) + assert.NotEqual(t, "", npe.Detail()) + //assert.Equals(t, tt.err.Reason, npe.Reason) // NOTE: reason detail is skipped; it's a detail + } + }) + } +} + +type result struct { + wantDNSNames []string + wantIps []net.IP + wantEmails []string + wantUsernames []string +} + +func emptyResult() result { + return result{ + wantDNSNames: []string{}, + wantIps: []net.IP{}, + wantEmails: []string{}, + wantUsernames: []string{}, + } +} + +func Test_splitSSHPrincipals(t *testing.T) { + type test struct { + cert *ssh.Certificate + r result + wantErr bool + } + var tests = map[string]func(t *testing.T) test{ + "fail/unexpected-cert-type": func(t *testing.T) test { + r := emptyResult() + return test{ + cert: &ssh.Certificate{ + CertType: uint32(0), + }, + r: r, + wantErr: true, + } + }, + "fail/user-ip": func(t *testing.T) test { + r := emptyResult() + r.wantIps = []net.IP{net.ParseIP("127.0.0.1")} + return test{ + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{"127.0.0.1"}, + }, + r: r, + wantErr: true, + } + }, + "fail/user-uri": func(t *testing.T) test { + r := emptyResult() + return test{ + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{"https://host.local/"}, + }, + r: r, + wantErr: true, + } + }, + "fail/host-uri": func(t *testing.T) test { + r := emptyResult() + return test{ + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{"https://host.local/"}, + }, + r: r, + wantErr: true, + } + }, + "ok/host-dns": func(t *testing.T) test { + r := emptyResult() + r.wantDNSNames = []string{"host.example.com"} + return test{ + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{"host.example.com"}, + }, + r: r, + wantErr: false, + } + }, + "ok/host-ip": func(t *testing.T) test { + r := emptyResult() + r.wantIps = []net.IP{net.ParseIP("127.0.0.1")} + return test{ + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{"127.0.0.1"}, + }, + r: r, + wantErr: false, + } + }, + "ok/host-email": func(t *testing.T) test { + r := emptyResult() + r.wantEmails = []string{"ops@work"} + return test{ + cert: &ssh.Certificate{ + CertType: ssh.HostCert, + ValidPrincipals: []string{"ops@work"}, + }, + r: r, + wantErr: false, + } + }, + "ok/user-localhost": func(t *testing.T) test { + r := emptyResult() + r.wantUsernames = []string{"localhost"} // when type is User cert, this is considered a username; not a DNS + return test{ + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{"localhost"}, + }, + r: r, + wantErr: false, + } + }, + "ok/user-username-with-period": func(t *testing.T) test { + r := emptyResult() + r.wantUsernames = []string{"x.joe"} + return test{ + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{"x.joe"}, + }, + r: r, + wantErr: false, + } + }, + "ok/user-maillike": func(t *testing.T) test { + r := emptyResult() + r.wantEmails = []string{"ops@work"} + return test{ + cert: &ssh.Certificate{ + CertType: ssh.UserCert, + ValidPrincipals: []string{"ops@work"}, + }, + r: r, + wantErr: false, + } + }, + } + for name, prep := range tests { + tt := prep(t) + t.Run(name, func(t *testing.T) { + gotDNSNames, gotIps, gotEmails, gotUsernames, err := splitSSHPrincipals(tt.cert) + if (err != nil) != tt.wantErr { + t.Errorf("splitSSHPrincipals() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !cmp.Equal(tt.r.wantDNSNames, gotDNSNames) { + t.Errorf("splitSSHPrincipals() DNS names diff =\n%s", cmp.Diff(tt.r.wantDNSNames, gotDNSNames)) + } + if !cmp.Equal(tt.r.wantIps, gotIps) { + t.Errorf("splitSSHPrincipals() IPs diff =\n%s", cmp.Diff(tt.r.wantIps, gotIps)) + } + if !cmp.Equal(tt.r.wantEmails, gotEmails) { + t.Errorf("splitSSHPrincipals() Emails diff =\n%s", cmp.Diff(tt.r.wantEmails, gotEmails)) + } + if !cmp.Equal(tt.r.wantUsernames, gotUsernames) { + t.Errorf("splitSSHPrincipals() Usernames diff =\n%s", cmp.Diff(tt.r.wantUsernames, gotUsernames)) + } + }) + } +} + +func Test_removeDuplicates(t *testing.T) { + tests := []struct { + name string + input []string + want []string + }{ + { + name: "empty-slice", + input: []string{}, + want: []string{}, + }, + { + name: "single-item", + input: []string{"x"}, + want: []string{"x"}, + }, + { + name: "ok", + input: []string{"x", "y", "x", "z", "x", "z", "y"}, + want: []string{"x", "y", "z"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := removeDuplicates(tt.input); !reflect.DeepEqual(got, tt.want) { + t.Errorf("removeDuplicates() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_removeDuplicateIPNets(t *testing.T) { + tests := []struct { + name string + input []*net.IPNet + want []*net.IPNet + }{ + { + name: "empty-slice", + input: []*net.IPNet{}, + want: []*net.IPNet{}, + }, + { + name: "single-item", + input: []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + }, + want: []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + }, + }, + { + name: "multiple", + input: []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + { + IP: net.ParseIP("192.168.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + { + IP: net.ParseIP("10.10.0.0"), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + { + IP: net.ParseIP("192.168.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + { + IP: net.ParseIP("192.168.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, + want: []*net.IPNet{ + { + IP: net.ParseIP("127.0.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + { + IP: net.ParseIP("192.168.0.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + { + IP: net.ParseIP("10.10.0.0"), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotRet := removeDuplicateIPNets(tt.input); !reflect.DeepEqual(gotRet, tt.want) { + t.Errorf("removeDuplicateIPNets() = %v, want %v", gotRet, tt.want) + } + }) + } +} + +func TestNamePolicyError_Error(t *testing.T) { + type fields struct { + Reason NamePolicyReason + NameType NameType + Name string + detail string + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "dns-not-allowed", + fields: fields{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: "www.example.com", + }, + want: "dns name \"www.example.com\" not allowed", + }, + { + name: "dns-cannot-parse-domain", + fields: fields{ + Reason: CannotParseDomain, + NameType: DNSNameType, + Name: "www.example.com", + }, + want: "cannot parse dns domain \"www.example.com\"", + }, + { + name: "email-cannot-parse", + fields: fields{ + Reason: CannotParseRFC822Name, + NameType: EmailNameType, + Name: "mail@example.com", + }, + want: "cannot parse email rfc822Name \"mail@example.com\"", + }, + { + name: "uri-cannot-match", + fields: fields{ + Reason: CannotMatchNameToConstraint, + NameType: URINameType, + Name: "https://*.local", + }, + want: "error matching uri name \"https://*.local\" to constraint", + }, + { + name: "unknown", + fields: fields{ + Reason: -1, + NameType: DNSNameType, + Name: "some name", + detail: "detail string", + }, + want: "unknown error reason (-1): detail string", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &NamePolicyError{ + Reason: tt.fields.Reason, + NameType: tt.fields.NameType, + Name: tt.fields.Name, + detail: tt.fields.detail, + } + if got := e.Error(); got != tt.want { + t.Errorf("NamePolicyError.Error() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/policy/options.go b/policy/options.go new file mode 100755 index 00000000..f08f9180 --- /dev/null +++ b/policy/options.go @@ -0,0 +1,386 @@ +package policy + +import ( + "fmt" + "net" + "strings" + + "golang.org/x/net/idna" +) + +type NamePolicyOption func(e *NamePolicyEngine) error + +// TODO: wrap (more) errors; and prove a set of known (exported) errors + +func WithSubjectCommonNameVerification() NamePolicyOption { + return func(e *NamePolicyEngine) error { + e.verifySubjectCommonName = true + return nil + } +} + +func WithAllowLiteralWildcardNames() NamePolicyOption { + return func(e *NamePolicyEngine) error { + e.allowLiteralWildcardNames = true + return nil + } +} + +func WithPermittedCommonNames(commonNames ...string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + normalizedCommonNames := make([]string, len(commonNames)) + for i, commonName := range commonNames { + normalizedCommonName, err := normalizeAndValidateCommonName(commonName) + if err != nil { + return fmt.Errorf("cannot parse permitted common name constraint %q: %w", commonName, err) + } + normalizedCommonNames[i] = normalizedCommonName + } + g.permittedCommonNames = normalizedCommonNames + return nil + } +} + +func WithExcludedCommonNames(commonNames ...string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + normalizedCommonNames := make([]string, len(commonNames)) + for i, commonName := range commonNames { + normalizedCommonName, err := normalizeAndValidateCommonName(commonName) + if err != nil { + return fmt.Errorf("cannot parse excluded common name constraint %q: %w", commonName, err) + } + normalizedCommonNames[i] = normalizedCommonName + } + g.excludedCommonNames = normalizedCommonNames + return nil + } +} + +func WithPermittedDNSDomains(domains ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + normalizedDomains := make([]string, len(domains)) + for i, domain := range domains { + normalizedDomain, err := normalizeAndValidateDNSDomainConstraint(domain) + if err != nil { + return fmt.Errorf("cannot parse permitted domain constraint %q: %w", domain, err) + } + normalizedDomains[i] = normalizedDomain + } + e.permittedDNSDomains = normalizedDomains + return nil + } +} + +func WithExcludedDNSDomains(domains ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + normalizedDomains := make([]string, len(domains)) + for i, domain := range domains { + normalizedDomain, err := normalizeAndValidateDNSDomainConstraint(domain) + if err != nil { + return fmt.Errorf("cannot parse excluded domain constraint %q: %w", domain, err) + } + normalizedDomains[i] = normalizedDomain + } + e.excludedDNSDomains = normalizedDomains + return nil + } +} + +func WithPermittedIPRanges(ipRanges ...*net.IPNet) NamePolicyOption { + return func(e *NamePolicyEngine) error { + e.permittedIPRanges = ipRanges + return nil + } +} + +func WithPermittedCIDRs(cidrs ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + networks := make([]*net.IPNet, len(cidrs)) + for i, cidr := range cidrs { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("cannot parse permitted CIDR constraint %q", cidr) + } + networks[i] = nw + } + e.permittedIPRanges = networks + return nil + } +} + +func WithExcludedCIDRs(cidrs ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + networks := make([]*net.IPNet, len(cidrs)) + for i, cidr := range cidrs { + _, nw, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("cannot parse excluded CIDR constraint %q", cidr) + } + networks[i] = nw + } + e.excludedIPRanges = networks + return nil + } +} + +func WithPermittedIPsOrCIDRs(ipsOrCIDRs ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + networks := make([]*net.IPNet, len(ipsOrCIDRs)) + for i, ipOrCIDR := range ipsOrCIDRs { + _, nw, err := net.ParseCIDR(ipOrCIDR) + if err == nil { + networks[i] = nw + } else if ip := net.ParseIP(ipOrCIDR); ip != nil { + networks[i] = networkFor(ip) + } else { + return fmt.Errorf("cannot parse permitted constraint %q as IP nor CIDR", ipOrCIDR) + } + } + e.permittedIPRanges = networks + return nil + } +} + +func WithExcludedIPsOrCIDRs(ipsOrCIDRs ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + networks := make([]*net.IPNet, len(ipsOrCIDRs)) + for i, ipOrCIDR := range ipsOrCIDRs { + _, nw, err := net.ParseCIDR(ipOrCIDR) + if err == nil { + networks[i] = nw + } else if ip := net.ParseIP(ipOrCIDR); ip != nil { + networks[i] = networkFor(ip) + } else { + return fmt.Errorf("cannot parse excluded constraint %q as IP nor CIDR", ipOrCIDR) + } + } + e.excludedIPRanges = networks + return nil + } +} + +func WithExcludedIPRanges(ipRanges ...*net.IPNet) NamePolicyOption { + return func(e *NamePolicyEngine) error { + e.excludedIPRanges = ipRanges + return nil + } +} + +func WithPermittedEmailAddresses(emailAddresses ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + normalizedEmailAddresses := make([]string, len(emailAddresses)) + for i, email := range emailAddresses { + normalizedEmailAddress, err := normalizeAndValidateEmailConstraint(email) + if err != nil { + return fmt.Errorf("cannot parse permitted email constraint %q: %w", email, err) + } + normalizedEmailAddresses[i] = normalizedEmailAddress + } + e.permittedEmailAddresses = normalizedEmailAddresses + return nil + } +} + +func WithExcludedEmailAddresses(emailAddresses ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + normalizedEmailAddresses := make([]string, len(emailAddresses)) + for i, email := range emailAddresses { + normalizedEmailAddress, err := normalizeAndValidateEmailConstraint(email) + if err != nil { + return fmt.Errorf("cannot parse excluded email constraint %q: %w", email, err) + } + normalizedEmailAddresses[i] = normalizedEmailAddress + } + e.excludedEmailAddresses = normalizedEmailAddresses + return nil + } +} + +func WithPermittedURIDomains(uriDomains ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + normalizedURIDomains := make([]string, len(uriDomains)) + for i, domain := range uriDomains { + normalizedURIDomain, err := normalizeAndValidateURIDomainConstraint(domain) + if err != nil { + return fmt.Errorf("cannot parse permitted URI domain constraint %q: %w", domain, err) + } + normalizedURIDomains[i] = normalizedURIDomain + } + e.permittedURIDomains = normalizedURIDomains + return nil + } +} + +func WithExcludedURIDomains(domains ...string) NamePolicyOption { + return func(e *NamePolicyEngine) error { + normalizedURIDomains := make([]string, len(domains)) + for i, domain := range domains { + normalizedURIDomain, err := normalizeAndValidateURIDomainConstraint(domain) + if err != nil { + return fmt.Errorf("cannot parse excluded URI domain constraint %q: %w", domain, err) + } + normalizedURIDomains[i] = normalizedURIDomain + } + e.excludedURIDomains = normalizedURIDomains + return nil + } +} + +func WithPermittedPrincipals(principals ...string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + g.permittedPrincipals = principals + return nil + } +} + +func WithExcludedPrincipals(principals ...string) NamePolicyOption { + return func(g *NamePolicyEngine) error { + g.excludedPrincipals = principals + return nil + } +} + +func networkFor(ip net.IP) *net.IPNet { + var mask net.IPMask + if !isIPv4(ip) { + mask = net.CIDRMask(128, 128) + } else { + mask = net.CIDRMask(32, 32) + } + nw := &net.IPNet{ + IP: ip, + Mask: mask, + } + return nw +} + +func isIPv4(ip net.IP) bool { + return ip.To4() != nil +} + +func normalizeAndValidateCommonName(constraint string) (string, error) { + normalizedConstraint := strings.ToLower(strings.TrimSpace(constraint)) + if normalizedConstraint == "" { + return "", fmt.Errorf("contraint %q can not be empty or white space string", constraint) + } + if normalizedConstraint == "*" { + return "", fmt.Errorf("wildcard constraint %q is not supported", constraint) + } + return normalizedConstraint, nil +} + +func normalizeAndValidateDNSDomainConstraint(constraint string) (string, error) { + normalizedConstraint := strings.ToLower(strings.TrimSpace(constraint)) + if normalizedConstraint == "" { + return "", fmt.Errorf("contraint %q can not be empty or white space string", constraint) + } + if strings.Contains(normalizedConstraint, "..") { + return "", fmt.Errorf("domain constraint %q cannot have empty labels", constraint) + } + if strings.HasPrefix(normalizedConstraint, ".") { + return "", fmt.Errorf("domain constraint %q with wildcard should start with *", constraint) + } + if strings.LastIndex(normalizedConstraint, "*") > 0 { + return "", fmt.Errorf("domain constraint %q can only have wildcard as starting character", constraint) + } + if len(normalizedConstraint) >= 2 && normalizedConstraint[0] == '*' && normalizedConstraint[1] != '.' { + return "", fmt.Errorf("wildcard character in domain constraint %q can only be used to match (full) labels", constraint) + } + if strings.HasPrefix(normalizedConstraint, "*.") { + normalizedConstraint = normalizedConstraint[1:] // cut off wildcard character; keep the period + } + normalizedConstraint, err := idna.Lookup.ToASCII(normalizedConstraint) + if err != nil { + return "", fmt.Errorf("domain constraint %q can not be converted to ASCII: %w", constraint, err) + } + if _, ok := domainToReverseLabels(normalizedConstraint); !ok { + return "", fmt.Errorf("cannot parse domain constraint %q", constraint) + } + return normalizedConstraint, nil +} + +func normalizeAndValidateEmailConstraint(constraint string) (string, error) { + normalizedConstraint := strings.ToLower(strings.TrimSpace(constraint)) + if normalizedConstraint == "" { + return "", fmt.Errorf("email contraint %q can not be empty or white space string", constraint) + } + if strings.Contains(normalizedConstraint, "*") { + return "", fmt.Errorf("email constraint %q cannot contain asterisk wildcard", constraint) + } + if strings.Count(normalizedConstraint, "@") > 1 { + return "", fmt.Errorf("email constraint %q contains too many @ characters", constraint) + } + if normalizedConstraint[0] == '@' { + normalizedConstraint = normalizedConstraint[1:] // remove the leading @ as wildcard for emails + } + if normalizedConstraint[0] == '.' { + return "", fmt.Errorf("email constraint %q cannot start with period", constraint) + } + if strings.Contains(normalizedConstraint, "@") { + mailbox, ok := parseRFC2821Mailbox(normalizedConstraint) + if !ok { + return "", fmt.Errorf("cannot parse email constraint %q as RFC 2821 mailbox", constraint) + } + // According to RFC 5280, section 7.5, emails are considered to match if the local part is + // an exact match and the host (domain) part matches the ASCII representation (case-insensitive): + // https://datatracker.ietf.org/doc/html/rfc5280#section-7.5 + domainASCII, err := idna.Lookup.ToASCII(mailbox.domain) + if err != nil { + return "", fmt.Errorf("email constraint %q domain part %q cannot be converted to ASCII: %w", constraint, mailbox.domain, err) + } + normalizedConstraint = mailbox.local + "@" + domainASCII + } else { + var err error + normalizedConstraint, err = idna.Lookup.ToASCII(normalizedConstraint) + if err != nil { + return "", fmt.Errorf("email constraint %q cannot be converted to ASCII: %w", constraint, err) + } + } + if _, ok := domainToReverseLabels(normalizedConstraint); !ok { + return "", fmt.Errorf("cannot parse email domain constraint %q", constraint) + } + return normalizedConstraint, nil +} + +func normalizeAndValidateURIDomainConstraint(constraint string) (string, error) { + normalizedConstraint := strings.ToLower(strings.TrimSpace(constraint)) + if normalizedConstraint == "" { + return "", fmt.Errorf("URI domain contraint %q cannot be empty or white space string", constraint) + } + if strings.Contains(normalizedConstraint, "://") { + return "", fmt.Errorf("URI domain constraint %q contains scheme (not supported yet)", constraint) + } + if strings.Contains(normalizedConstraint, "..") { + return "", fmt.Errorf("URI domain constraint %q cannot have empty labels", constraint) + } + if strings.HasPrefix(normalizedConstraint, ".") { + return "", fmt.Errorf("URI domain constraint %q with wildcard should start with *", constraint) + } + if strings.LastIndex(normalizedConstraint, "*") > 0 { + return "", fmt.Errorf("URI domain constraint %q can only have wildcard as starting character", constraint) + } + if strings.HasPrefix(normalizedConstraint, "*.") { + normalizedConstraint = normalizedConstraint[1:] // cut off wildcard character; keep the period + } + // we're being strict with square brackets in domains; we don't allow them, no matter what + if strings.Contains(normalizedConstraint, "[") || strings.Contains(normalizedConstraint, "]") { + return "", fmt.Errorf("URI domain constraint %q contains invalid square brackets", constraint) + } + if _, _, err := net.SplitHostPort(normalizedConstraint); err == nil { + // a successful split (likely) with host and port; we don't currently allow ports in the config + return "", fmt.Errorf("URI domain constraint %q cannot contain port", constraint) + } + // check if the host part of the URI domain constraint is an IP + if net.ParseIP(normalizedConstraint) != nil { + return "", fmt.Errorf("URI domain constraint %q cannot be an IP", constraint) + } + normalizedConstraint, err := idna.Lookup.ToASCII(normalizedConstraint) + if err != nil { + return "", fmt.Errorf("URI domain constraint %q cannot be converted to ASCII: %w", constraint, err) + } + _, ok := domainToReverseLabels(normalizedConstraint) + if !ok { + return "", fmt.Errorf("cannot parse URI domain constraint %q", constraint) + } + return normalizedConstraint, nil +} diff --git a/policy/options_117_test.go b/policy/options_117_test.go new file mode 100644 index 00000000..916eefe2 --- /dev/null +++ b/policy/options_117_test.go @@ -0,0 +1,125 @@ +//go:build !go1.18 +// +build !go1.18 + +package policy + +import "testing" + +func Test_normalizeAndValidateURIDomainConstraint(t *testing.T) { + tests := []struct { + name string + constraint string + want string + wantErr bool + }{ + { + name: "fail/empty-constraint", + constraint: "", + want: "", + wantErr: true, + }, + { + name: "fail/scheme-https", + constraint: `https://*.local`, + want: "", + wantErr: true, + }, + { + name: "fail/too-many-asterisks", + constraint: "**.local", + want: "", + wantErr: true, + }, + { + name: "fail/empty-label", + constraint: "..local", + want: "", + wantErr: true, + }, + { + name: "fail/empty-reverse", + constraint: ".", + want: "", + wantErr: true, + }, + { + name: "fail/no-asterisk", + constraint: ".example.com", + want: "", + wantErr: true, + }, + { + name: "fail/domain-with-port", + constraint: "host.local:8443", + want: "", + wantErr: true, + }, + { + name: "fail/ipv4", + constraint: "127.0.0.1", + want: "", + wantErr: true, + }, + { + name: "fail/ipv6-brackets", + constraint: "[::1]", + want: "", + wantErr: true, + }, + { + name: "fail/ipv6-no-brackets", + constraint: "::1", + want: "", + wantErr: true, + }, + { + name: "fail/ipv6-no-brackets", + constraint: "[::1", + want: "", + wantErr: true, + }, + { + name: "fail/idna-internationalized-domain-name-lookup", + constraint: `\00local`, + want: "", + wantErr: true, + }, + { + name: "ok/wildcard", + constraint: "*.local", + want: ".local", + wantErr: false, + }, + { + name: "ok/specific-domain", + constraint: "example.local", + want: "example.local", + wantErr: false, + }, + { + name: "ok/idna-internationalized-domain-name-lookup", + constraint: `*.bücher.example.com`, + want: ".xn--bcher-kva.example.com", + wantErr: false, + }, + { + // IDNA2003 vs. 2008 deviation: https://unicode.org/reports/tr46/#Deviations results + // in a difference between Go 1.18 and lower versions. Go 1.18 expects ".xn--fa-hia.de"; not .fass.de. + name: "ok/idna-internationalized-domain-name-lookup-deviation", + constraint: `*.faß.de`, + want: ".fass.de", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := normalizeAndValidateURIDomainConstraint(tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("normalizeAndValidateURIDomainConstraint() error = %v, wantErr %v", err, tt.wantErr) + } + if got != tt.want { + t.Errorf("normalizeAndValidateURIDomainConstraint() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/policy/options_118_test.go b/policy/options_118_test.go new file mode 100644 index 00000000..6fa2ded4 --- /dev/null +++ b/policy/options_118_test.go @@ -0,0 +1,125 @@ +//go:build go1.18 +// +build go1.18 + +package policy + +import "testing" + +func Test_normalizeAndValidateURIDomainConstraint(t *testing.T) { + tests := []struct { + name string + constraint string + want string + wantErr bool + }{ + { + name: "fail/empty-constraint", + constraint: "", + want: "", + wantErr: true, + }, + { + name: "fail/scheme-https", + constraint: `https://*.local`, + want: "", + wantErr: true, + }, + { + name: "fail/too-many-asterisks", + constraint: "**.local", + want: "", + wantErr: true, + }, + { + name: "fail/empty-label", + constraint: "..local", + want: "", + wantErr: true, + }, + { + name: "fail/empty-reverse", + constraint: ".", + want: "", + wantErr: true, + }, + { + name: "fail/domain-with-port", + constraint: "host.local:8443", + want: "", + wantErr: true, + }, + { + name: "fail/no-asterisk", + constraint: ".example.com", + want: "", + wantErr: true, + }, + { + name: "fail/ipv4", + constraint: "127.0.0.1", + want: "", + wantErr: true, + }, + { + name: "fail/ipv6-brackets", + constraint: "[::1]", + want: "", + wantErr: true, + }, + { + name: "fail/ipv6-no-brackets", + constraint: "::1", + want: "", + wantErr: true, + }, + { + name: "fail/ipv6-no-brackets", + constraint: "[::1", + want: "", + wantErr: true, + }, + { + name: "fail/idna-internationalized-domain-name-lookup", + constraint: `\00local`, + want: "", + wantErr: true, + }, + { + name: "ok/wildcard", + constraint: "*.local", + want: ".local", + wantErr: false, + }, + { + name: "ok/specific-domain", + constraint: "example.local", + want: "example.local", + wantErr: false, + }, + { + name: "ok/idna-internationalized-domain-name-lookup", + constraint: `*.bücher.example.com`, + want: ".xn--bcher-kva.example.com", + wantErr: false, + }, + { + // IDNA2003 vs. 2008 deviation: https://unicode.org/reports/tr46/#Deviations results + // in a difference between Go 1.18 and lower versions. Go 1.18 expects ".xn--fa-hia.de"; not .fass.de. + name: "ok/idna-internationalized-domain-name-lookup-deviation", + constraint: `*.faß.de`, + want: ".xn--fa-hia.de", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := normalizeAndValidateURIDomainConstraint(tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("normalizeAndValidateURIDomainConstraint() error = %v, wantErr %v", err, tt.wantErr) + } + if got != tt.want { + t.Errorf("normalizeAndValidateURIDomainConstraint() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/policy/options_test.go b/policy/options_test.go new file mode 100644 index 00000000..697afecf --- /dev/null +++ b/policy/options_test.go @@ -0,0 +1,660 @@ +package policy + +import ( + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" +) + +func Test_normalizeAndValidateCommonName(t *testing.T) { + tests := []struct { + name string + constraint string + want string + wantErr bool + }{ + { + name: "fail/empty-constraint", + constraint: "", + want: "", + wantErr: true, + }, + { + name: "fail/wildcard", + constraint: "*", + want: "", + wantErr: true, + }, + { + name: "ok", + constraint: "step", + want: "step", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := normalizeAndValidateCommonName(tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("normalizeAndValidateCommonName() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("normalizeAndValidateCommonName() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_normalizeAndValidateDNSDomainConstraint(t *testing.T) { + tests := []struct { + name string + constraint string + want string + wantErr bool + }{ + { + name: "fail/empty-constraint", + constraint: "", + want: "", + wantErr: true, + }, + { + name: "fail/wildcard-partial-label", + constraint: "*xxxx.local", + want: "", + wantErr: true, + }, + { + name: "fail/wildcard-in-the-middle", + constraint: "x.*.local", + want: "", + wantErr: true, + }, + { + name: "fail/empty-label", + constraint: "..local", + want: "", + wantErr: true, + }, + { + name: "fail/empty-reverse", + constraint: ".", + want: "", + wantErr: true, + }, + { + name: "fail/no-asterisk", + constraint: ".example.com", + want: "", + wantErr: true, + }, + { + name: "fail/idna-internationalized-domain-name-lookup", + constraint: `\00.local`, // invalid IDNA ASCII character + want: "", + wantErr: true, + }, + { + name: "ok/wildcard", + constraint: "*.local", + want: ".local", + wantErr: false, + }, + { + name: "ok/specific-domain", + constraint: "example.local", + want: "example.local", + wantErr: false, + }, + { + name: "ok/idna-internationalized-domain-name-punycode", + constraint: "*.xn--fsq.jp", // Example value from https://www.w3.org/International/articles/idn-and-iri/ + want: ".xn--fsq.jp", + wantErr: false, + }, + { + name: "ok/idna-internationalized-domain-name-lookup-transformed", + constraint: "*.例.jp", // Example value from https://www.w3.org/International/articles/idn-and-iri/ + want: ".xn--fsq.jp", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := normalizeAndValidateDNSDomainConstraint(tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("normalizeAndValidateDNSDomainConstraint() error = %v, wantErr %v", err, tt.wantErr) + } + if got != tt.want { + t.Errorf("normalizeAndValidateDNSDomainConstraint() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_normalizeAndValidateEmailConstraint(t *testing.T) { + tests := []struct { + name string + constraint string + want string + wantErr bool + }{ + { + name: "fail/empty-constraint", + constraint: "", + want: "", + wantErr: true, + }, + { + name: "fail/asterisk", + constraint: "*.local", + want: "", + wantErr: true, + }, + { + name: "fail/period", + constraint: ".local", + want: "", + wantErr: true, + }, + { + name: "fail/@period", + constraint: "@.local", + want: "", + wantErr: true, + }, + { + name: "fail/too-many-@s", + constraint: "@local@example.com", + want: "", + wantErr: true, + }, + { + name: "fail/parse-mailbox", + constraint: "mail@example.com" + string(byte(0)), + want: "", + wantErr: true, + }, + { + name: "fail/idna-internationalized-domain", + constraint: "mail@xn--bla.local", + want: "", + wantErr: true, + }, + { + name: "fail/idna-internationalized-domain-name-lookup", + constraint: `\00local`, + want: "", + wantErr: true, + }, + { + name: "fail/parse-domain", + constraint: "x..example.com", + want: "", + wantErr: true, + }, + { + name: "ok/wildcard", + constraint: "@local", + want: "local", + wantErr: false, + }, + { + name: "ok/specific-mail", + constraint: "mail@local", + want: "mail@local", + wantErr: false, + }, + // TODO(hs): fix the below; doesn't get past parseRFC2821Mailbox; I think it should be allowed. + // { + // name: "ok/idna-internationalized-local", + // constraint: `bücher@local`, + // want: "bücher@local", + // wantErr: false, + // }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := normalizeAndValidateEmailConstraint(tt.constraint) + if (err != nil) != tt.wantErr { + t.Errorf("normalizeAndValidateEmailConstraint() error = %v, wantErr %v", err, tt.wantErr) + } + if got != tt.want { + t.Errorf("normalizeAndValidateEmailConstraint() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNew(t *testing.T) { + type test struct { + options []NamePolicyOption + want *NamePolicyEngine + wantErr bool + } + var tests = map[string]func(t *testing.T) test{ + "fail/with-permitted-common-name": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithPermittedCommonNames("*"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-excluded-common-name": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithExcludedCommonNames(""), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-permitted-dns-domains": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithPermittedDNSDomains("**.local"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-excluded-dns-domains": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithExcludedDNSDomains("**.local"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-permitted-cidrs": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithPermittedCIDRs("127.0.0.1//24"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-excluded-cidrs": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithExcludedCIDRs("127.0.0.1//24"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-permitted-ipsOrCIDRs-cidr": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithPermittedIPsOrCIDRs("127.0.0.1//24"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-permitted-ipsOrCIDRs-ip": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithPermittedIPsOrCIDRs("127.0.0:1"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-excluded-ipsOrCIDRs-cidr": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithExcludedIPsOrCIDRs("127.0.0.1//24"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-excluded-ipsOrCIDRs-ip": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithExcludedIPsOrCIDRs("127.0.0:1"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-permitted-emails": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithPermittedEmailAddresses("*.local"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-excluded-emails": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithExcludedEmailAddresses("*.local"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-permitted-uris": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithPermittedURIDomains("**.local"), + }, + want: nil, + wantErr: true, + } + }, + "fail/with-excluded-uris": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{ + WithExcludedURIDomains("**.local"), + }, + want: nil, + wantErr: true, + } + }, + "ok/default": func(t *testing.T) test { + return test{ + options: []NamePolicyOption{}, + want: &NamePolicyEngine{}, + wantErr: false, + } + }, + "ok/subject-verification": func(t *testing.T) test { + options := []NamePolicyOption{ + WithSubjectCommonNameVerification(), + } + return test{ + options: options, + want: &NamePolicyEngine{ + verifySubjectCommonName: true, + }, + wantErr: false, + } + }, + "ok/literal-wildcards": func(t *testing.T) test { + options := []NamePolicyOption{ + WithAllowLiteralWildcardNames(), + } + return test{ + options: options, + want: &NamePolicyEngine{ + allowLiteralWildcardNames: true, + }, + wantErr: false, + } + }, + "ok/with-permitted-dns-wildcard-domains": func(t *testing.T) test { + options := []NamePolicyOption{ + WithPermittedDNSDomains("*.local", "*.example.com"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + permittedDNSDomains: []string{".local", ".example.com"}, + numberOfDNSDomainConstraints: 2, + totalNumberOfPermittedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-excluded-dns-domains": func(t *testing.T) test { + options := []NamePolicyOption{ + WithExcludedDNSDomains("*.local", "*.example.com"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + excludedDNSDomains: []string{".local", ".example.com"}, + numberOfDNSDomainConstraints: 2, + totalNumberOfExcludedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-permitted-ip-ranges": func(t *testing.T) test { + _, nw1, err := net.ParseCIDR("127.0.0.1/24") + assert.NoError(t, err) + _, nw2, err := net.ParseCIDR("192.168.0.1/24") + assert.NoError(t, err) + options := []NamePolicyOption{ + WithPermittedIPRanges(nw1, nw2), + } + return test{ + options: options, + want: &NamePolicyEngine{ + permittedIPRanges: []*net.IPNet{ + nw1, nw2, + }, + numberOfIPRangeConstraints: 2, + totalNumberOfPermittedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-excluded-ip-ranges": func(t *testing.T) test { + _, nw1, err := net.ParseCIDR("127.0.0.1/24") + assert.NoError(t, err) + _, nw2, err := net.ParseCIDR("192.168.0.1/24") + assert.NoError(t, err) + options := []NamePolicyOption{ + WithExcludedIPRanges(nw1, nw2), + } + return test{ + options: options, + want: &NamePolicyEngine{ + excludedIPRanges: []*net.IPNet{ + nw1, nw2, + }, + numberOfIPRangeConstraints: 2, + totalNumberOfExcludedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-permitted-cidrs": func(t *testing.T) test { + _, nw1, err := net.ParseCIDR("127.0.0.1/24") + assert.NoError(t, err) + _, nw2, err := net.ParseCIDR("192.168.0.1/24") + assert.NoError(t, err) + options := []NamePolicyOption{ + WithPermittedCIDRs("127.0.0.1/24", "192.168.0.1/24"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + permittedIPRanges: []*net.IPNet{ + nw1, nw2, + }, + numberOfIPRangeConstraints: 2, + totalNumberOfPermittedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-excluded-cidrs": func(t *testing.T) test { + _, nw1, err := net.ParseCIDR("127.0.0.1/24") + assert.NoError(t, err) + _, nw2, err := net.ParseCIDR("192.168.0.1/24") + assert.NoError(t, err) + options := []NamePolicyOption{ + WithExcludedCIDRs("127.0.0.1/24", "192.168.0.1/24"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + excludedIPRanges: []*net.IPNet{ + nw1, nw2, + }, + numberOfIPRangeConstraints: 2, + totalNumberOfExcludedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-permitted-ipsOrCIDRs-cidr": func(t *testing.T) test { + _, nw1, err := net.ParseCIDR("127.0.0.1/24") + assert.NoError(t, err) + _, nw2, err := net.ParseCIDR("192.168.0.31/32") + assert.NoError(t, err) + _, nw3, err := net.ParseCIDR("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128") + assert.NoError(t, err) + options := []NamePolicyOption{ + WithPermittedIPsOrCIDRs("127.0.0.1/24", "192.168.0.31", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + permittedIPRanges: []*net.IPNet{ + nw1, nw2, nw3, + }, + numberOfIPRangeConstraints: 3, + totalNumberOfPermittedConstraints: 3, + totalNumberOfConstraints: 3, + }, + wantErr: false, + } + }, + "ok/with-excluded-ipsOrCIDRs-cidr": func(t *testing.T) test { + _, nw1, err := net.ParseCIDR("127.0.0.1/24") + assert.NoError(t, err) + _, nw2, err := net.ParseCIDR("192.168.0.31/32") + assert.NoError(t, err) + _, nw3, err := net.ParseCIDR("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128") + assert.NoError(t, err) + options := []NamePolicyOption{ + WithExcludedIPsOrCIDRs("127.0.0.1/24", "192.168.0.31", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + excludedIPRanges: []*net.IPNet{ + nw1, nw2, nw3, + }, + numberOfIPRangeConstraints: 3, + totalNumberOfExcludedConstraints: 3, + totalNumberOfConstraints: 3, + }, + wantErr: false, + } + }, + "ok/with-permitted-emails": func(t *testing.T) test { + options := []NamePolicyOption{ + WithPermittedEmailAddresses("mail@local", "@example.com"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + permittedEmailAddresses: []string{"mail@local", "example.com"}, + numberOfEmailAddressConstraints: 2, + totalNumberOfPermittedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-excluded-emails": func(t *testing.T) test { + options := []NamePolicyOption{ + WithExcludedEmailAddresses("mail@local", "@example.com"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + excludedEmailAddresses: []string{"mail@local", "example.com"}, + numberOfEmailAddressConstraints: 2, + totalNumberOfExcludedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-permitted-uris": func(t *testing.T) test { + options := []NamePolicyOption{ + WithPermittedURIDomains("host.local", "*.example.com"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + permittedURIDomains: []string{"host.local", ".example.com"}, + numberOfURIDomainConstraints: 2, + totalNumberOfPermittedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-excluded-uris": func(t *testing.T) test { + options := []NamePolicyOption{ + WithExcludedURIDomains("host.local", "*.example.com"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + excludedURIDomains: []string{"host.local", ".example.com"}, + numberOfURIDomainConstraints: 2, + totalNumberOfExcludedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-permitted-principals": func(t *testing.T) test { + options := []NamePolicyOption{ + WithPermittedPrincipals("root", "ops"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + permittedPrincipals: []string{"root", "ops"}, + numberOfPrincipalConstraints: 2, + totalNumberOfPermittedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + "ok/with-excluded-principals": func(t *testing.T) test { + options := []NamePolicyOption{ + WithExcludedPrincipals("root", "ops"), + } + return test{ + options: options, + want: &NamePolicyEngine{ + excludedPrincipals: []string{"root", "ops"}, + numberOfPrincipalConstraints: 2, + totalNumberOfExcludedConstraints: 2, + totalNumberOfConstraints: 2, + }, + wantErr: false, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + got, err := New(tc.options...) + if (err != nil) != tc.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tc.wantErr) + return + } + if !cmp.Equal(tc.want, got, cmp.AllowUnexported(NamePolicyEngine{})) { + t.Errorf("New() diff =\n %s", cmp.Diff(tc.want, got, cmp.AllowUnexported(NamePolicyEngine{}))) + } + }) + } +} diff --git a/policy/ssh.go b/policy/ssh.go new file mode 100644 index 00000000..725f9b7b --- /dev/null +++ b/policy/ssh.go @@ -0,0 +1,9 @@ +package policy + +import ( + "golang.org/x/crypto/ssh" +) + +type SSHNamePolicyEngine interface { + IsSSHCertificateAllowed(cert *ssh.Certificate) error +} diff --git a/policy/validate.go b/policy/validate.go new file mode 100644 index 00000000..ee6f7e9c --- /dev/null +++ b/policy/validate.go @@ -0,0 +1,647 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// The code in this file is an adapted version of the code in +// https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +package policy + +import ( + "bytes" + "fmt" + "net" + "net/url" + "reflect" + "strings" + + "golang.org/x/net/idna" + + "go.step.sm/crypto/x509util" +) + +// validateNames verifies that all names are allowed. +func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailAddresses []string, uris []*url.URL, principals []string) error { + + // nothing to compare against; return early + if e.totalNumberOfConstraints == 0 { + return nil + } + + // TODO: set limit on total of all names validated? In x509 there's a limit on the number of comparisons + // that protects the CA from a DoS (i.e. many heavy comparisons). The x509 implementation takes + // this number as a total of all checks and keeps a (pointer to a) counter of the number of checks + // executed so far. + + // TODO: gather all errors, or return early? Currently we return early on the first wrong name; check might fail for multiple names. + // Perhaps make that an option? + for _, dns := range dnsNames { + // if there are DNS names to check, no DNS constraints set, but there are other permitted constraints, + // then return error, because DNS should be explicitly configured to be allowed in that case. In case there are + // (other) excluded constraints, we'll allow a DNS (implicit allow; currently). + if e.numberOfDNSDomainConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { + return &NamePolicyError{ + Reason: NotAllowed, + NameType: DNSNameType, + Name: dns, + detail: fmt.Sprintf("dns %q is not explicitly permitted by any constraint", dns), + } + } + didCutWildcard := false + parsedDNS := dns + if strings.HasPrefix(parsedDNS, "*.") { + parsedDNS = parsedDNS[1:] + didCutWildcard = true + } + // TODO(hs): fix this above; we need separate rule for Subject Common Name? + parsedDNS, err := idna.Lookup.ToASCII(parsedDNS) + if err != nil { + return &NamePolicyError{ + Reason: CannotParseDomain, + NameType: DNSNameType, + Name: dns, + detail: fmt.Sprintf("dns %q cannot be converted to ASCII", dns), + } + } + if didCutWildcard { + parsedDNS = "*" + parsedDNS + } + if _, ok := domainToReverseLabels(parsedDNS); !ok { // TODO(hs): this also fails with spaces + return &NamePolicyError{ + Reason: CannotParseDomain, + NameType: DNSNameType, + Name: dns, + detail: fmt.Sprintf("cannot parse dns %q", dns), + } + } + if err := checkNameConstraints(DNSNameType, dns, parsedDNS, + func(parsedName, constraint interface{}) (bool, error) { + return e.matchDomainConstraint(parsedName.(string), constraint.(string)) + }, e.permittedDNSDomains, e.excludedDNSDomains); err != nil { + return err + } + } + + for _, ip := range ips { + if e.numberOfIPRangeConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { + return &NamePolicyError{ + Reason: NotAllowed, + NameType: IPNameType, + Name: ip.String(), + detail: fmt.Sprintf("ip %q is not explicitly permitted by any constraint", ip.String()), + } + } + if err := checkNameConstraints(IPNameType, ip.String(), ip, + func(parsedName, constraint interface{}) (bool, error) { + return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet)) + }, e.permittedIPRanges, e.excludedIPRanges); err != nil { + return err + } + } + + for _, email := range emailAddresses { + if e.numberOfEmailAddressConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { + return &NamePolicyError{ + Reason: NotAllowed, + NameType: EmailNameType, + Name: email, + detail: fmt.Sprintf("email %q is not explicitly permitted by any constraint", email), + } + } + mailbox, ok := parseRFC2821Mailbox(email) + if !ok { + return &NamePolicyError{ + Reason: CannotParseRFC822Name, + NameType: EmailNameType, + Name: email, + detail: fmt.Sprintf("invalid rfc822Name %q", mailbox), + } + } + // According to RFC 5280, section 7.5, emails are considered to match if the local part is + // an exact match and the host (domain) part matches the ASCII representation (case-insensitive): + // https://datatracker.ietf.org/doc/html/rfc5280#section-7.5 + domainASCII, err := idna.ToASCII(mailbox.domain) + if err != nil { + return &NamePolicyError{ + Reason: CannotParseDomain, + NameType: EmailNameType, + Name: email, + detail: fmt.Errorf("cannot parse email domain %q: %w", email, err).Error(), + } + } + mailbox.domain = domainASCII + if err := checkNameConstraints(EmailNameType, email, mailbox, + func(parsedName, constraint interface{}) (bool, error) { + return e.matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string)) + }, e.permittedEmailAddresses, e.excludedEmailAddresses); err != nil { + return err + } + } + + // TODO(hs): fix internationalization for URIs (IRIs) + + for _, uri := range uris { + if e.numberOfURIDomainConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { + return &NamePolicyError{ + Reason: NotAllowed, + NameType: URINameType, + Name: uri.String(), + detail: fmt.Sprintf("uri %q is not explicitly permitted by any constraint", uri.String()), + } + } + // TODO(hs): ideally we'd like the uri.String() to be the original contents; now + // it's transformed into ASCII. Prevent that here? + if err := checkNameConstraints(URINameType, uri.String(), uri, + func(parsedName, constraint interface{}) (bool, error) { + return e.matchURIConstraint(parsedName.(*url.URL), constraint.(string)) + }, e.permittedURIDomains, e.excludedURIDomains); err != nil { + return err + } + } + + for _, principal := range principals { + if e.numberOfPrincipalConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 { + return &NamePolicyError{ + Reason: NotAllowed, + NameType: PrincipalNameType, + Name: principal, + detail: fmt.Sprintf("username principal %q is not explicitly permitted by any constraint", principal), + } + } + // TODO: some validation? I.e. allowed characters? + if err := checkNameConstraints(PrincipalNameType, principal, principal, + func(parsedName, constraint interface{}) (bool, error) { + return matchPrincipalConstraint(parsedName.(string), constraint.(string)) + }, e.permittedPrincipals, e.excludedPrincipals); err != nil { + return err + } + } + + // if all checks out, all SANs are allowed + return nil +} + +// validateCommonName verifies that the Subject Common Name is allowed +func (e *NamePolicyEngine) validateCommonName(commonName string) error { + + // nothing to compare against; return early + if e.totalNumberOfConstraints == 0 { + return nil + } + + // empty common names are not validated + if commonName == "" { + return nil + } + + if e.numberOfCommonNameConstraints > 0 { + // Check the Common Name using its dedicated matcher if constraints have been + // configured. If no error is returned from matching, the Common Name was + // explicitly allowed and nil is returned immediately. + if err := checkNameConstraints(CNNameType, commonName, commonName, + func(parsedName, constraint interface{}) (bool, error) { + return matchCommonNameConstraint(parsedName.(string), constraint.(string)) + }, e.permittedCommonNames, e.excludedCommonNames); err == nil { + return nil + } + } + + // When an error was returned or when no constraints were configured for Common Names, + // the Common Name should be validated against the other types of constraints too, + // according to what type it is. + dnsNames, ips, emails, uris := x509util.SplitSANs([]string{commonName}) + + err := e.validateNames(dnsNames, ips, emails, uris, []string{}) + + if pe, ok := err.(*NamePolicyError); ok { + // override the name type with CN + pe.NameType = CNNameType + } + + return err +} + +// checkNameConstraints checks that a name, of type nameType is permitted. +// The argument parsedName contains the parsed form of name, suitable for passing +// to the match function. +func checkNameConstraints( + nameType NameType, + name string, + parsedName interface{}, + match func(parsedName, constraint interface{}) (match bool, err error), + permitted, excluded interface{}) error { + + excludedValue := reflect.ValueOf(excluded) + + for i := 0; i < excludedValue.Len(); i++ { + constraint := excludedValue.Index(i).Interface() + match, err := match(parsedName, constraint) + if err != nil { + return &NamePolicyError{ + Reason: CannotMatchNameToConstraint, + NameType: nameType, + Name: name, + detail: err.Error(), + } + } + + if match { + return &NamePolicyError{ + Reason: NotAllowed, + NameType: nameType, + Name: name, + detail: fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint), + } + } + } + + permittedValue := reflect.ValueOf(permitted) + + ok := true + for i := 0; i < permittedValue.Len(); i++ { + constraint := permittedValue.Index(i).Interface() + var err error + if ok, err = match(parsedName, constraint); err != nil { + return &NamePolicyError{ + Reason: CannotMatchNameToConstraint, + NameType: nameType, + Name: name, + detail: err.Error(), + } + } + + if ok { + break + } + } + + if !ok { + return &NamePolicyError{ + Reason: NotAllowed, + NameType: nameType, + Name: name, + detail: fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name), + } + } + + return nil +} + +// domainToReverseLabels converts a textual domain name like foo.example.com to +// the list of labels in reverse order, e.g. ["com", "example", "foo"]. +func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) { + for len(domain) > 0 { + if i := strings.LastIndexByte(domain, '.'); i == -1 { + reverseLabels = append(reverseLabels, domain) + domain = "" + } else { + reverseLabels = append(reverseLabels, domain[i+1:]) + domain = domain[:i] + } + } + + if len(reverseLabels) > 0 && reverseLabels[0] == "" { + // An empty label at the end indicates an absolute value. + return nil, false + } + + for _, label := range reverseLabels { + if label == "" { + // Empty labels are otherwise invalid. + return nil, false + } + + for _, c := range label { + if c < 33 || c > 126 { + // Invalid character. + return nil, false + } + } + } + + return reverseLabels, true +} + +// rfc2821Mailbox represents a “mailbox” (which is an email address to most +// people) by breaking it into the “local” (i.e. before the '@') and “domain” +// parts. +type rfc2821Mailbox struct { + local, domain string +} + +// parseRFC2821Mailbox parses an email address into local and domain parts, +// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280, +// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The +// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”. +func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) { + if in == "" { + return mailbox, false + } + + localPartBytes := make([]byte, 0, len(in)/2) + + if in[0] == '"' { + // Quoted-string = DQUOTE *qcontent DQUOTE + // non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127 + // qcontent = qtext / quoted-pair + // qtext = non-whitespace-control / + // %d33 / %d35-91 / %d93-126 + // quoted-pair = ("\" text) / obs-qp + // text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text + // + // (Names beginning with “obs-” are the obsolete syntax from RFC 2822, + // Section 4. Since it has been 16 years, we no longer accept that.) + in = in[1:] + QuotedString: + for { + if in == "" { + return mailbox, false + } + c := in[0] + in = in[1:] + + switch { + case c == '"': + break QuotedString + + case c == '\\': + // quoted-pair + if in == "" { + return mailbox, false + } + if in[0] == 11 || + in[0] == 12 || + (1 <= in[0] && in[0] <= 9) || + (14 <= in[0] && in[0] <= 127) { + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + } else { + return mailbox, false + } + + case c == 11 || + c == 12 || + // Space (char 32) is not allowed based on the + // BNF, but RFC 3696 gives an example that + // assumes that it is. Several “verified” + // errata continue to argue about this point. + // We choose to accept it. + c == 32 || + c == 33 || + c == 127 || + (1 <= c && c <= 8) || + (14 <= c && c <= 31) || + (35 <= c && c <= 91) || + (93 <= c && c <= 126): + // qtext + localPartBytes = append(localPartBytes, c) + + default: + return mailbox, false + } + } + } else { + // Atom ("." Atom)* + NextChar: + for len(in) > 0 { + // atext from RFC 2822, Section 3.2.4 + c := in[0] + + switch { + case c == '\\': + // Examples given in RFC 3696 suggest that + // escaped characters can appear outside of a + // quoted string. Several “verified” errata + // continue to argue the point. We choose to + // accept it. + in = in[1:] + if in == "" { + return mailbox, false + } + fallthrough + + case ('0' <= c && c <= '9') || + ('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || + c == '-' || c == '/' || c == '=' || c == '?' || + c == '^' || c == '_' || c == '`' || c == '{' || + c == '|' || c == '}' || c == '~' || c == '.': + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + + default: + break NextChar + } + } + + if len(localPartBytes) == 0 { + return mailbox, false + } + + // From RFC 3696, Section 3: + // “period (".") may also appear, but may not be used to start + // or end the local part, nor may two or more consecutive + // periods appear.” + twoDots := []byte{'.', '.'} + if localPartBytes[0] == '.' || + localPartBytes[len(localPartBytes)-1] == '.' || + bytes.Contains(localPartBytes, twoDots) { + return mailbox, false + } + } + + if in == "" || in[0] != '@' { + return mailbox, false + } + in = in[1:] + + // The RFC species a format for domains, but that's known to be + // violated in practice so we accept that anything after an '@' is the + // domain part. + if _, ok := domainToReverseLabels(in); !ok { + return mailbox, false + } + + mailbox.local = string(localPartBytes) + mailbox.domain = in + return mailbox, true +} + +// matchDomainConstraint matches a domain against the given constraint +func (e *NamePolicyEngine) matchDomainConstraint(domain, constraint string) (bool, error) { + // The meaning of zero length constraints is not specified, but this + // code follows NSS and accepts them as matching everything. + if constraint == "" { + return true, nil + } + + // A single whitespace seems to be considered a valid domain, but we don't allow it. + if domain == " " { + return false, nil + } + + // Block domains that start with just a period + if domain[0] == '.' { + return false, nil + } + + // Block wildcard domains that don't start with exactly "*." (i.e. double wildcards and such) + if domain[0] == '*' && domain[1] != '.' { + return false, nil + } + + // Check if the domain starts with a wildcard and return early if not allowed + if strings.HasPrefix(domain, "*.") && !e.allowLiteralWildcardNames { + return false, nil + } + + // Only allow asterisk at the start of the domain; we don't allow them as part of a domain label or as a (sub)domain label (currently) + if strings.LastIndex(domain, "*") > 0 { + return false, nil + } + + // Don't allow constraints with empty labels in any position + if strings.Contains(constraint, "..") { + return false, nil + } + + domainLabels, ok := domainToReverseLabels(domain) + if !ok { + return false, fmt.Errorf("cannot parse domain %q", domain) + } + + // RFC 5280 says that a leading period in a domain name means that at + // least one label must be prepended, but only for URI and email + // constraints, not DNS constraints. The code also supports that + // behavior for DNS constraints. In our adaptation of the original + // Go stdlib x509 Name Constraint implementation we look for exactly + // one subdomain, currently. + + mustHaveSubdomains := false + if constraint[0] == '.' { + mustHaveSubdomains = true + constraint = constraint[1:] + } + + constraintLabels, ok := domainToReverseLabels(constraint) + if !ok { + return false, fmt.Errorf("cannot parse domain constraint %q", constraint) + } + + expectedNumberOfLabels := len(constraintLabels) + if mustHaveSubdomains { + // we expect exactly one more label if it starts with the "canonical" x509 "wildcard": "." + // in the future we could extend this to support multiple additional labels and/or more + // complex matching. + expectedNumberOfLabels++ + } + + if len(domainLabels) != expectedNumberOfLabels { + return false, nil + } + + for i, constraintLabel := range constraintLabels { + if !strings.EqualFold(constraintLabel, domainLabels[i]) { + return false, nil + } + } + + return true, nil +} + +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func matchIPConstraint(ip net.IP, constraint *net.IPNet) (bool, error) { + + // TODO(hs): this is code from Go library, but I got some unexpected result: + // with permitted net 127.0.0.0/24, 127.0.0.1 is NOT allowed. When parsing 127.0.0.1 as net.IP + // which is in the IPAddresses slice, the underlying length is 16. The contraint.IP has a length + // of 4 instead. I currently don't believe that this is a bug in Go now, but why is it like that? + // Is there a difference because we're not operating on a sans []string slice? Or is the Go + // implementation stricter regarding IPv4 vs. IPv6? I've been bitten by some unfortunate differences + // between the two before (i.e. IPv4 in IPv6; IP SANS in ACME) + // if len(ip) != len(constraint.IP) { + // return false, nil + // } + + // for i := range ip { + // if mask := constraint.Mask[i]; ip[i]&mask != constraint.IP[i]&mask { + // return false, nil + // } + // } + + contained := constraint.Contains(ip) // TODO(hs): validate that this is the correct behavior; also check IPv4-in-IPv6 (again) + + return contained, nil +} + +// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go +func (e *NamePolicyEngine) matchEmailConstraint(mailbox rfc2821Mailbox, constraint string) (bool, error) { + if strings.Contains(constraint, "@") { + constraintMailbox, ok := parseRFC2821Mailbox(constraint) + if !ok { + return false, fmt.Errorf("cannot parse constraint %q", constraint) + } + return mailbox.local == constraintMailbox.local && strings.EqualFold(mailbox.domain, constraintMailbox.domain), nil + } + + // Otherwise the constraint is like a DNS constraint of the domain part + // of the mailbox. + return e.matchDomainConstraint(mailbox.domain, constraint) +} + +// matchURIConstraint matches an URL against a constraint +func (e *NamePolicyEngine) matchURIConstraint(uri *url.URL, constraint string) (bool, error) { + // From RFC 5280, Section 4.2.1.10: + // “a uniformResourceIdentifier that does not include an authority + // component with a host name specified as a fully qualified domain + // name (e.g., if the URI either does not include an authority + // component or includes an authority component in which the host name + // is specified as an IP address), then the application MUST reject the + // certificate.” + + host := uri.Host + if host == "" { + return false, fmt.Errorf("URI with empty host (%q) cannot be matched against constraints", uri.String()) + } + + // Block hosts with the wildcard character; no exceptions, also not when wildcards allowed. + if strings.Contains(host, "*") { + return false, fmt.Errorf("URI host %q cannot contain asterisk", uri.String()) + } + + if strings.Contains(host, ":") && !strings.HasSuffix(host, "]") { + var err error + host, _, err = net.SplitHostPort(host) + if err != nil { + return false, err + } + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") || + net.ParseIP(host) != nil { + return false, fmt.Errorf("URI with IP %q cannot be matched against constraints", uri.String()) + } + + // TODO(hs): add checks for scheme, path, etc.; either here, or in a different constraint matcher (to keep this one simple) + + return e.matchDomainConstraint(host, constraint) +} + +// matchPrincipalConstraint performs a string literal equality check against a constraint. +func matchPrincipalConstraint(principal, constraint string) (bool, error) { + // allow any plain principal when wildcard constraint is used + if constraint == "*" { + return true, nil + } + return strings.EqualFold(principal, constraint), nil +} + +// matchCommonNameConstraint performs a string literal equality check against constraint. +func matchCommonNameConstraint(commonName, constraint string) (bool, error) { + // wildcard constraint is (currently) not supported for common names + if constraint == "*" { + return false, nil + } + return strings.EqualFold(commonName, constraint), nil +} diff --git a/policy/x509.go b/policy/x509.go new file mode 100644 index 00000000..8b6c4de9 --- /dev/null +++ b/policy/x509.go @@ -0,0 +1,14 @@ +package policy + +import ( + "crypto/x509" + "net" +) + +type X509NamePolicyEngine interface { + IsX509CertificateAllowed(cert *x509.Certificate) error + IsX509CertificateRequestAllowed(csr *x509.CertificateRequest) error + AreSANsAllowed(sans []string) error + IsDNSAllowed(dns string) error + IsIPAllowed(ip net.IP) error +}