diff --git a/api/api.go b/api/api.go index f8d11ff5..a92b7902 100644 --- a/api/api.go +++ b/api/api.go @@ -18,19 +18,19 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/cli/crypto/tlsutil" ) // Authority is the interface implemented by a CA authority. type Authority interface { - Authorize(ott string) ([]interface{}, error) + Authorize(ott string) ([]provisioner.SignOption, error) GetTLSOptions() *tlsutil.TLSOptions Root(shasum string) (*x509.Certificate, error) - Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) + Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) - GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error) + GetProvisioners(cursor string, limit int) (provisioner.List, string, error) GetEncryptedKey(kid string) (string, error) GetRoots() (federation []*x509.Certificate, err error) GetFederation() ([]*x509.Certificate, error) @@ -161,11 +161,11 @@ type SignRequest struct { // ProvisionersResponse is the response object that returns the list of // provisioners. type ProvisionersResponse struct { - Provisioners []*authority.Provisioner `json:"provisioners"` - NextCursor string `json:"nextCursor"` + Provisioners provisioner.List `json:"provisioners"` + NextCursor string `json:"nextCursor"` } -// ProvisionerKeyResponse is the response object that returns the encryptoed key +// ProvisionerKeyResponse is the response object that returns the encrypted key // of a provisioner. type ProvisionerKeyResponse struct { Key string `json:"key"` @@ -266,18 +266,18 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { return } - signOpts := authority.SignOptions{ + opts := provisioner.Options{ NotBefore: body.NotBefore, NotAfter: body.NotAfter, } - extraOpts, err := h.Authority.Authorize(body.OTT) + signOpts, err := h.Authority.Authorize(body.OTT) if err != nil { WriteError(w, Unauthorized(err)) return } - cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, signOpts, extraOpts...) + cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { WriteError(w, Forbidden(err)) return diff --git a/api/api_test.go b/api/api_test.go index c4907b8d..80879ef5 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -24,7 +24,7 @@ import ( "time" "github.com/go-chi/chi" - "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/jose" @@ -410,22 +410,22 @@ func TestSignRequest_Validate(t *testing.T) { type mockAuthority struct { ret1, ret2 interface{} err error - authorize func(ott string) ([]interface{}, error) + authorize func(ott string) ([]provisioner.SignOption, error) getTLSOptions func() *tlsutil.TLSOptions root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) + sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) - getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error) + getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error) } -func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) { +func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) { if m.authorize != nil { return m.authorize(ott) } - return m.ret1.([]interface{}), m.err + return m.ret1.([]provisioner.SignOption), m.err } func (m *mockAuthority) GetTLSOptions() *tlsutil.TLSOptions { @@ -442,9 +442,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { return m.ret1.(*x509.Certificate), m.err } -func (m *mockAuthority) Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) { +func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) { if m.sign != nil { - return m.sign(cr, signOpts, extraOpts...) + return m.sign(cr, opts, signOpts...) } return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err } @@ -456,11 +456,11 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509. return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err } -func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) ([]*authority.Provisioner, string, error) { +func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { if m.getProvisioners != nil { return m.getProvisioners(nextCursor, limit) } - return m.ret1.([]*authority.Provisioner), m.ret2.(string), m.err + return m.ret1.(provisioner.List), m.ret2.(string), m.err } func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { @@ -597,7 +597,7 @@ func Test_caHandler_Sign(t *testing.T) { tests := []struct { name string input string - certAttrOpts []interface{} + certAttrOpts []provisioner.SignOption autherr error cert *x509.Certificate root *x509.Certificate @@ -617,7 +617,7 @@ func Test_caHandler_Sign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, - authorize: func(ott string) ([]interface{}, error) { + authorize: func(ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr }, getTLSOptions: func() *tlsutil.TLSOptions { @@ -723,14 +723,14 @@ func Test_caHandler_Provisioners(t *testing.T) { t.Fatal(err) } - p := []*authority.Provisioner{ - { + p := provisioner.List{ + &provisioner.JWK{ Type: "JWK", Name: "max", EncryptedKey: "abc", Key: &key, }, - { + &provisioner.JWK{ Type: "JWK", Name: "mariano", EncryptedKey: "def", diff --git a/authority/authority.go b/authority/authority.go index 5a0cf1ab..950f6e4a 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -4,10 +4,10 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" - "fmt" "sync" "time" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/x509util" ) @@ -16,18 +16,14 @@ const legacyAuthority = "step-certificate-authority" // Authority implements the Certificate Authority internal interface. type Authority struct { - config *Config - rootX509Certs []*x509.Certificate - intermediateIdentity *x509util.Identity - validateOnce bool - certificates *sync.Map - ottMap *sync.Map - startTime time.Time - provisionerIDIndex *sync.Map - encryptedKeyIndex *sync.Map - provisionerKeySetIndex *sync.Map - sortedProvisioners provisionerSlice - audiences []string + config *Config + rootX509Certs []*x509.Certificate + intermediateIdentity *x509util.Identity + validateOnce bool + certificates *sync.Map + ottMap *sync.Map + startTime time.Time + provisioners *provisioner.Collection // Do not re-initialize initOnce bool } @@ -39,31 +35,11 @@ func New(config *Config) (*Authority, error) { return nil, err } - // Get sorted provisioners - var sorted provisionerSlice - if config.AuthorityConfig != nil { - sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners) - if err != nil { - return nil, err - } - } - - // Define audiences: legacy + possible urls without the ports. - // The CA might have proxies in front so we cannot rely on the port. - audiences := []string{legacyAuthority} - for _, name := range config.DNSNames { - audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name)) - } - var a = &Authority{ - config: config, - certificates: new(sync.Map), - ottMap: new(sync.Map), - provisionerIDIndex: new(sync.Map), - encryptedKeyIndex: new(sync.Map), - provisionerKeySetIndex: new(sync.Map), - sortedProvisioners: sorted, - audiences: audiences, + config: config, + certificates: new(sync.Map), + ottMap: new(sync.Map), + provisioners: provisioner.NewCollection(config.getAudiences()), } if err := a.init(); err != nil { return nil, err @@ -120,14 +96,15 @@ func (a *Authority) init() error { } } + // Store all the provisioners for _, p := range a.config.AuthorityConfig.Provisioners { - a.provisionerIDIndex.Store(p.ID(), p) - if len(p.EncryptedKey) != 0 { - a.encryptedKeyIndex.Store(p.Key.KeyID, p.EncryptedKey) + if err := a.provisioners.Store(p); err != nil { + return err } } - a.startTime = time.Now() + // JWT numeric dates are seconds. + a.startTime = time.Now().Truncate(time.Second) // Set flag indicating that initialization has been completed, and should // not be repeated. a.initOnce = true diff --git a/authority/authority_test.go b/authority/authority_test.go index 1020f808..e008b22d 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" stepJOSE "github.com/smallstep/cli/jose" ) @@ -16,22 +17,22 @@ func testAuthority(t *testing.T) *Authority { clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) disableRenewal := true - p := []*Provisioner{ - { + p := provisioner.List{ + &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, }, - { + &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, }, - { + &provisioner.JWK{ Name: "dev", Type: "JWK", Key: maxjwk, - Claims: &ProvisionerClaims{ + Claims: &provisioner.Claims{ DisableRenewal: &disableRenewal, }, }, @@ -113,24 +114,18 @@ func TestAuthorityNew(t *testing.T) { assert.True(t, auth.initOnce) assert.NotNil(t, auth.intermediateIdentity) for _, p := range tc.config.AuthorityConfig.Provisioners { - _p, ok := auth.provisionerIDIndex.Load(p.ID()) + _p, ok := auth.provisioners.Load(p.GetID()) assert.True(t, ok) assert.Equals(t, p, _p) - if len(p.EncryptedKey) > 0 { - key, ok := auth.encryptedKeyIndex.Load(p.Key.KeyID) + if kid, encryptedKey, ok := p.GetEncryptedKey(); ok { + key, ok := auth.provisioners.LoadEncryptedKey(kid) assert.True(t, ok) - assert.Equals(t, p.EncryptedKey, key) + assert.Equals(t, encryptedKey, key) } } // sanity check - _, ok = auth.provisionerIDIndex.Load("fooo") + _, ok = auth.provisioners.Load("fooo") assert.False(t, ok) - - assert.Equals(t, auth.audiences, []string{ - "step-certificate-authority", - "https://127.0.0.1/sign", - "https://127.0.0.1/1.0/sign", - }) } } }) diff --git a/authority/authorize.go b/authority/authorize.go index 5566b17f..d0d04121 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -2,14 +2,13 @@ package authority import ( "crypto/x509" - "encoding/asn1" "net/http" - "net/url" + "strings" "time" "github.com/pkg/errors" - "github.com/smallstep/cli/crypto/x509util" - "gopkg.in/square/go-jose.v2/jwt" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/jose" ) type idUsed struct { @@ -17,49 +16,21 @@ type idUsed struct { Subject string `json:"sub,omitempty"` } -// Claims extends jwt.Claims with step attributes. +// Claims extends jose.Claims with step attributes. type Claims struct { - jwt.Claims - SANs []string `json:"sans,omitempty"` -} - -// matchesAudience returns true if A and B share at least one element. -func matchesAudience(as, bs []string) bool { - if len(bs) == 0 || len(as) == 0 { - return false - } - - for _, b := range bs { - for _, a := range as { - if b == a || stripPort(a) == stripPort(b) { - return true - } - } - } - return false -} - -// stripPort attempts to strip the port from the given url. If parsing the url -// produces errors it will just return the passed argument. -func stripPort(rawurl string) string { - u, err := url.Parse(rawurl) - if err != nil { - return rawurl - } - u.Host = u.Hostname() - return u.String() + jose.Claims + SANs []string `json:"sans,omitempty"` + Email string `json:"email,omitempty"` + Nonce string `json:"nonce,omitempty"` } // Authorize authorizes a signature request by validating and authenticating // a OTT that must be sent w/ the request. -func (a *Authority) Authorize(ott string) ([]interface{}, error) { - var ( - errContext = map[string]interface{}{"ott": ott} - claims = Claims{} - ) +func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) { + var errContext = map[string]interface{}{"ott": ott} // Validate payload - token, err := jwt.ParseSigned(ott) + token, err := jose.ParseSigned(ott) if err != nil { return nil, &apiError{errors.Wrapf(err, "authorize: error parsing token"), http.StatusUnauthorized, errContext} @@ -68,86 +39,52 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) { // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. + var claims Claims if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, &apiError{err, http.StatusUnauthorized, errContext} } - kid := token.Headers[0].KeyID // JWT will only have 1 header. - if len(kid) == 0 { - return nil, &apiError{errors.New("authorize: token KeyID cannot be empty"), - http.StatusUnauthorized, errContext} - } - pid := claims.Issuer + ":" + kid - val, ok := a.provisionerIDIndex.Load(pid) - if !ok { - return nil, &apiError{errors.Errorf("authorize: provisioner with id %s not found", pid), - http.StatusUnauthorized, errContext} - } - p, ok := val.(*Provisioner) - if !ok { - return nil, &apiError{errors.Errorf("authorize: invalid provisioner type"), - http.StatusInternalServerError, errContext} - } - - if err = token.Claims(p.Key, &claims); err != nil { - return nil, &apiError{err, http.StatusUnauthorized, errContext} - } - - // According to "rfc7519 JSON Web Token" acceptable skew should be no - // more than a few minutes. - if err = claims.ValidateWithLeeway(jwt.Expected{ - Issuer: p.Name, - }, time.Minute); err != nil { - return nil, &apiError{errors.Wrapf(err, "authorize: invalid token"), - http.StatusUnauthorized, errContext} - } // Do not accept tokens issued before the start of the ca. // This check is meant as a stopgap solution to the current lack of a persistence layer. if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if claims.IssuedAt > 0 && claims.IssuedAt.Time().Before(a.startTime) { - return nil, &apiError{errors.New("token issued before the bootstrap of certificate authority"), + return nil, &apiError{errors.New("authorize: token issued before the bootstrap of certificate authority"), http.StatusUnauthorized, errContext} } } - if !matchesAudience(claims.Audience, a.audiences) { - return nil, &apiError{errors.New("authorize: token audience invalid"), http.StatusUnauthorized, - errContext} - } - - if claims.Subject == "" { - return nil, &apiError{errors.New("authorize: token subject cannot be empty"), + // This method will also validate the audiences for JWK provisioners. + p, ok := a.provisioners.LoadByToken(token, &claims.Claims) + if !ok { + return nil, &apiError{ + errors.Errorf("authorize: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")), http.StatusUnauthorized, errContext} } - // NOTE: This is for backwards compatibility with older versions of cli - // and certificates. Older versions added the token subject as the only SAN - // in a CSR by default. - if len(claims.SANs) == 0 { - claims.SANs = []string{claims.Subject} - } - dnsNames, ips := x509util.SplitSANs(claims.SANs) - if err != nil { - return nil, err - } - - signOps := []interface{}{ - &commonNameClaim{claims.Subject}, - &dnsNamesClaim{dnsNames}, - &ipAddressesClaim{ips}, - p, - } - // Store the token to protect against reuse. - if _, ok := a.ottMap.LoadOrStore(claims.ID, &idUsed{ - UsedAt: time.Now().Unix(), - Subject: claims.Subject, - }); ok { - return nil, &apiError{errors.Errorf("token already used"), http.StatusUnauthorized, - errContext} + var reuseKey string + switch p.GetType() { + case provisioner.TypeJWK: + reuseKey = claims.ID + case provisioner.TypeOIDC: + reuseKey = claims.Nonce + } + if reuseKey != "" { + if _, ok := a.ottMap.LoadOrStore(reuseKey, &idUsed{ + UsedAt: time.Now().Unix(), + Subject: claims.Subject, + }); ok { + return nil, &apiError{errors.Errorf("authorize: token already used"), http.StatusUnauthorized, errContext} + } } - return signOps, nil + // Call the provisioner Authorize method to get the signing options + opts, err := p.Authorize(ott) + if err != nil { + return nil, &apiError{errors.Wrap(err, "authorize"), http.StatusUnauthorized, errContext} + } + + return opts, nil } // authorizeRenewal tries to locate the step provisioner extension, and checks @@ -157,46 +94,20 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) { // TODO(mariano): should we authorize by default? func (a *Authority) authorizeRenewal(crt *x509.Certificate) error { errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()} - for _, e := range crt.Extensions { - if e.Id.Equal(stepOIDProvisioner) { - var provisioner stepProvisionerASN1 - if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { - return &apiError{ - err: errors.Wrap(err, "error decoding step provisioner extension"), - code: http.StatusInternalServerError, - context: errContext, - } - } - - // Look for the provisioner, if it cannot be found, renewal will not - // be authorized. - pid := string(provisioner.Name) + ":" + string(provisioner.CredentialID) - val, ok := a.provisionerIDIndex.Load(pid) - if !ok { - return &apiError{ - err: errors.Errorf("not found: provisioner %s", pid), - code: http.StatusUnauthorized, - context: errContext, - } - } - p, ok := val.(*Provisioner) - if !ok { - return &apiError{ - err: errors.Errorf("invalid type: provisioner %s, type %T", pid, val), - code: http.StatusInternalServerError, - context: errContext, - } - } - if p.Claims.IsDisableRenewal() { - return &apiError{ - err: errors.Errorf("renew disabled: provisioner %s", pid), - code: http.StatusUnauthorized, - context: errContext, - } - } - return nil + p, ok := a.provisioners.LoadByCertificate(crt) + if !ok { + return &apiError{ + err: errors.New("provisioner not found"), + code: http.StatusUnauthorized, + context: errContext, + } + } + if err := p.AuthorizeRenewal(crt); err != nil { + return &apiError{ + err: err, + code: http.StatusUnauthorized, + context: errContext, } } - return nil } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 8ccd7a4d..64a9dc63 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -7,100 +7,52 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/cli/crypto/keys" - stepJOSE "github.com/smallstep/cli/jose" - jose "gopkg.in/square/go-jose.v2" - "gopkg.in/square/go-jose.v2/jwt" + "github.com/smallstep/cli/crypto/randutil" + "github.com/smallstep/cli/jose" ) -func TestMatchesAudience(t *testing.T) { - type matchesTest struct { - a, b []string - exp bool +func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { + sig, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), + ) + if err != nil { + return "", err } - tests := map[string]matchesTest{ - "false arg1 empty": { - a: []string{}, - b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, - exp: false, - }, - "false arg2 empty": { - a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, - b: []string{}, - exp: false, - }, - "false arg1,arg2 empty": { - a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, - b: []string{"step-gateway", "step-cli"}, - exp: false, - }, - "false": { - a: []string{"step-gateway", "step-cli"}, - b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, - exp: false, - }, - "true": { - a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"}, - b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, - exp: true, - }, - "true,portsA": { - a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"}, - b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, - exp: true, - }, - "true,portsB": { - a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"}, - b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"}, - exp: true, - }, - "true,portsAB": { - a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"}, - b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"}, - exp: true, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b)) - }) - } -} -func TestStripPort(t *testing.T) { - type args struct { - rawurl string + id, err := randutil.ASCII(64) + if err != nil { + return "", err } - tests := []struct { - name string - args args - want string + + claims := struct { + jose.Claims + SANS []string `json:"sans"` }{ - {"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"}, - {"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"}, - {"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := stripPort(tt.args.rawurl); got != tt.want { - t.Errorf("stripPort() = %v, want %v", got, tt.want) - } - }) + Claims: jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + }, + SANS: sans, } + return jose.Signed(sig).Claims(claims).CompactSerialize() } func TestAuthorize(t *testing.T) { a := testAuthority(t) - jwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_priv.jwk", - stepJOSE.WithPassword([]byte("pass"))) - assert.FatalError(t, err) - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, - (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) + key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) assert.FatalError(t, err) + // Invalid keys + keyNoKid := &jose.JSONWebKey{Key: key.Key, KeyID: ""} + keyBadKid := &jose.JSONWebKey{Key: key.Key, KeyID: "foo"} now := time.Now() - validIssuer := "step-cli" validAudience := []string{"https://test.ca.smallstep.com/sign"} @@ -120,100 +72,37 @@ func TestAuthorize(t *testing.T) { } }, "fail empty key id": func(t *testing.T) *authorizeTest { - _sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, - (&jose.SignerOptions{}).WithType("JWT")) - assert.FatalError(t, err) - cl := jwt.Claims{ - Subject: "test.smallstep.com", - Issuer: validIssuer, - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - ID: "43", - } - raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize() + raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyNoKid) assert.FatalError(t, err) return &authorizeTest{ auth: a, ott: raw, - err: &apiError{errors.New("authorize: token KeyID cannot be empty"), + err: &apiError{errors.New("authorize: provisioner not found or invalid audience"), http.StatusUnauthorized, context{"ott": raw}}, } }, "fail provisioner not found": func(t *testing.T) *authorizeTest { - _sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, - (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo")) - assert.FatalError(t, err) - - cl := jwt.Claims{ - Subject: "test.smallstep.com", - Issuer: validIssuer, - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - ID: "43", - } - raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize() + raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyBadKid) assert.FatalError(t, err) return &authorizeTest{ auth: a, ott: raw, - err: &apiError{errors.New("authorize: provisioner with id step-cli:foo not found"), + err: &apiError{errors.New("authorize: provisioner not found or invalid audience"), http.StatusUnauthorized, context{"ott": raw}}, } }, - "fail invalid provisioner": func(t *testing.T) *authorizeTest { - _a := testAuthority(t) - - _sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, - (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo")) - assert.FatalError(t, err) - - _a.provisionerIDIndex.Store(validIssuer+":foo", "42") - - cl := jwt.Claims{ - Subject: "test.smallstep.com", - Issuer: validIssuer, - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - ID: "43", - } - raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize() - assert.FatalError(t, err) - return &authorizeTest{ - auth: _a, - ott: raw, - err: &apiError{errors.New("authorize: invalid provisioner type"), - http.StatusInternalServerError, context{"ott": raw}}, - } - }, "fail invalid issuer": func(t *testing.T) *authorizeTest { - cl := jwt.Claims{ - Subject: "subject", - Issuer: "invalid-issuer", - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - } - raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + raw, err := generateToken("test.smallstep.com", "invalid-issuer", validAudience[0], nil, now, key) assert.FatalError(t, err) return &authorizeTest{ auth: a, ott: raw, - err: &apiError{errors.New("authorize: provisioner with id invalid-issuer:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc not found"), + err: &apiError{errors.New("authorize: provisioner not found or invalid audience"), http.StatusUnauthorized, context{"ott": raw}}, } }, "fail empty subject": func(t *testing.T) *authorizeTest { - cl := jwt.Claims{ - Subject: "", - Issuer: validIssuer, - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - } - raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + raw, err := generateToken("", validIssuer, validAudience[0], nil, now, key) assert.FatalError(t, err) return &authorizeTest{ auth: a, @@ -223,64 +112,34 @@ func TestAuthorize(t *testing.T) { } }, "fail verify-sig-failure": func(t *testing.T) *authorizeTest { - _, priv2, err := keys.GenerateDefaultKeyPair() - assert.FatalError(t, err) - invalidKeySig, err := jose.NewSigner(jose.SigningKey{ - Algorithm: jose.ES256, - Key: priv2, - }, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) - assert.FatalError(t, err) - cl := jwt.Claims{ - Subject: "test.smallstep.com", - Issuer: validIssuer, - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - } - raw, err := jwt.Signed(invalidKeySig).Claims(cl).CompactSerialize() + raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key) assert.FatalError(t, err) return &authorizeTest{ auth: a, - ott: raw, - err: &apiError{errors.New("square/go-jose: error in cryptographic primitive"), - http.StatusUnauthorized, context{"ott": raw}}, + ott: raw + "00", + err: &apiError{errors.New("authorize: error parsing claims: square/go-jose: error in cryptographic primitive"), + http.StatusUnauthorized, context{"ott": raw + "00"}}, } }, "fail token-already-used": func(t *testing.T) *authorizeTest { - cl := jwt.Claims{ - Subject: "test.smallstep.com", - Issuer: validIssuer, - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - ID: "42", - } - raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key) assert.FatalError(t, err) _, err = a.Authorize(raw) assert.FatalError(t, err) return &authorizeTest{ auth: a, ott: raw, - err: &apiError{errors.New("token already used"), + err: &apiError{errors.New("authorize: token already used"), http.StatusUnauthorized, context{"ott": raw}}, } }, "ok": func(t *testing.T) *authorizeTest { - cl := jwt.Claims{ - Subject: "test.smallstep.com", - Issuer: validIssuer, - NotBefore: jwt.NewNumericDate(now), - Expiry: jwt.NewNumericDate(now.Add(time.Minute)), - Audience: validAudience, - ID: "43", - } - raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize() + raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key) assert.FatalError(t, err) return &authorizeTest{ auth: a, ott: raw, - res: []interface{}{"1", "2", "3", "4"}, + res: []interface{}{"1", "2", "3", "4", "5", "6"}, } }, } diff --git a/authority/claims.go b/authority/claims.go deleted file mode 100644 index 9f38a810..00000000 --- a/authority/claims.go +++ /dev/null @@ -1,117 +0,0 @@ -package authority - -import ( - "net" - "reflect" - "time" - - "github.com/pkg/errors" - x509 "github.com/smallstep/cli/pkg/x509" -) - -// certClaim interface is implemented by types used to validate specific claims in a -// certificate request. -type certClaim interface { - Valid(crt *x509.Certificate) error -} - -// ValidateClaims returns nil if all the claims are validated, it will return -// the first error if a claim fails. -func validateClaims(crt *x509.Certificate, claims []certClaim) (err error) { - for _, c := range claims { - if err = c.Valid(crt); err != nil { - return err - } - } - return -} - -// commonNameClaim validates the common name of a certificate request. -type commonNameClaim struct { - name string -} - -// Valid checks that certificate request common name matches the one configured. -func (c *commonNameClaim) Valid(crt *x509.Certificate) error { - if crt.Subject.CommonName == "" { - return errors.New("common name cannot be empty") - } - if crt.Subject.CommonName != c.name { - return errors.Errorf("common name claim failed - got %s, want %s", crt.Subject.CommonName, c.name) - } - return nil -} - -type dnsNamesClaim struct { - names []string -} - -// Valid checks that certificate request DNS Names match those configured in -// the bootstrap (token) flow. -func (c *dnsNamesClaim) Valid(crt *x509.Certificate) error { - tokMap := make(map[string]int) - for _, e := range c.names { - tokMap[e] = 1 - } - crtMap := make(map[string]int) - for _, e := range crt.DNSNames { - crtMap[e] = 1 - } - if !reflect.DeepEqual(tokMap, crtMap) { - return errors.Errorf("DNS names claim failed - got %s, want %s", crt.DNSNames, c.names) - } - return nil -} - -type ipAddressesClaim struct { - ips []net.IP -} - -// Valid checks that certificate request IP Addresses match those configured in -// the bootstrap (token) flow. -func (c *ipAddressesClaim) Valid(crt *x509.Certificate) error { - tokMap := make(map[string]int) - for _, e := range c.ips { - tokMap[e.String()] = 1 - } - crtMap := make(map[string]int) - for _, e := range crt.IPAddresses { - crtMap[e.String()] = 1 - } - if !reflect.DeepEqual(tokMap, crtMap) { - return errors.Errorf("IP Addresses claim failed - got %v, want %v", crt.IPAddresses, c.ips) - } - return nil -} - -// certTemporalClaim validates the certificate temporal validity settings. -type certTemporalClaim struct { - min time.Duration - max time.Duration -} - -// Validate validates the certificate temporal validity settings. -func (ctc *certTemporalClaim) Valid(crt *x509.Certificate) error { - var ( - na = crt.NotAfter - nb = crt.NotBefore - d = na.Sub(nb) - now = time.Now() - ) - - if na.Before(now) { - return errors.Errorf("NotAfter: %v cannot be in the past", na) - } - if na.Before(nb) { - return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb) - } - if d < ctc.min { - return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v", - d, ctc.min) - } - if d > ctc.max { - return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v", - d, ctc.max) - } - return nil -} diff --git a/authority/claims_test.go b/authority/claims_test.go deleted file mode 100644 index d9c9d768..00000000 --- a/authority/claims_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package authority - -import ( - "crypto/x509/pkix" - "net" - "testing" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - x509 "github.com/smallstep/cli/pkg/x509" -) - -func TestCommonNameClaim_Valid(t *testing.T) { - tests := map[string]struct { - cnc certClaim - crt *x509.Certificate - err error - }{ - "empty-common-name": { - cnc: &commonNameClaim{name: "foo"}, - crt: &x509.Certificate{}, - err: errors.New("common name cannot be empty"), - }, - "wrong-common-name": { - cnc: &commonNameClaim{name: "foo"}, - crt: &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}}, - err: errors.New("common name claim failed - got bar, want foo"), - }, - "ok": { - cnc: &commonNameClaim{name: "foo"}, - crt: &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}}, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := tc.cnc.Valid(tc.crt) - if err != nil { - if assert.NotNil(t, tc.err) { - assert.Equals(t, tc.err.Error(), err.Error()) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestIPAddressesClaim_Valid(t *testing.T) { - tests := map[string]struct { - iac certClaim - crt *x509.Certificate - err error - }{ - "unexpected-ip-in-crt": { - iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}}, - crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}}, - err: errors.New("IP Addresses claim failed - got [127.0.0.1 1.1.1.1], want [127.0.0.1]"), - }, - "missing-ip-in-crt": { - iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}}, - crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}}, - err: errors.New("IP Addresses claim failed - got [127.0.0.1], want [127.0.0.1 1.1.1.1]"), - }, - "invalid-matcher-nonempty-ips": { - iac: &ipAddressesClaim{ips: []net.IP{}}, - crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}}, - err: errors.New("IP Addresses claim failed - got [127.0.0.1], want []"), - }, - "ok": { - iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}}, - crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}}, - }, - "ok-multiple-identical-ip-entries": { - iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}}, - crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1")}}, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := tc.iac.Valid(tc.crt) - if err != nil { - if assert.NotNil(t, tc.err) { - assert.Equals(t, tc.err.Error(), err.Error()) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestDNSNamesClaim_Valid(t *testing.T) { - tests := map[string]struct { - dnc certClaim - crt *x509.Certificate - err error - }{ - "unexpected-dns-name-in-crt": { - dnc: &dnsNamesClaim{names: []string{"foo"}}, - crt: &x509.Certificate{DNSNames: []string{"foo", "bar"}}, - err: errors.New("DNS names claim failed - got [foo bar], want [foo]"), - }, - "ok": { - dnc: &dnsNamesClaim{names: []string{"foo", "bar"}}, - crt: &x509.Certificate{DNSNames: []string{"bar", "foo"}}, - }, - "missing-dns-name-in-crt": { - dnc: &dnsNamesClaim{names: []string{"foo", "bar"}}, - crt: &x509.Certificate{DNSNames: []string{"foo"}}, - err: errors.New("DNS names claim failed - got [foo], want [foo bar]"), - }, - "ok-multiple-identical-dns-entries": { - dnc: &dnsNamesClaim{names: []string{"foo"}}, - crt: &x509.Certificate{DNSNames: []string{"foo", "foo", "foo"}}, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := tc.dnc.Valid(tc.crt) - if err != nil { - if assert.NotNil(t, tc.err) { - assert.Equals(t, tc.err.Error(), err.Error()) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} diff --git a/authority/config.go b/authority/config.go index 3bc8e810..406fd437 100644 --- a/authority/config.go +++ b/authority/config.go @@ -2,11 +2,13 @@ package authority import ( "encoding/json" + "fmt" "net" "os" "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/x509util" ) @@ -25,10 +27,10 @@ var ( Renegotiation: false, } defaultDisableRenewal = false - globalProvisionerClaims = ProvisionerClaims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, + globalProvisionerClaims = provisioner.Claims{ + MinTLSDur: &provisioner.Duration{5 * time.Minute}, + MaxTLSDur: &provisioner.Duration{24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{24 * time.Hour}, DisableRenewal: &defaultDisableRenewal, } ) @@ -50,16 +52,15 @@ type Config struct { // AuthConfig represents the configuration options for the authority. type AuthConfig struct { - Provisioners []*Provisioner `json:"provisioners,omitempty"` - Template *x509util.ASN1DN `json:"template,omitempty"` - Claims *ProvisionerClaims `json:"claims,omitempty"` - DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` + Provisioners provisioner.List `json:"provisioners"` + Template *x509util.ASN1DN `json:"template,omitempty"` + Claims *provisioner.Claims `json:"claims,omitempty"` + DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` } // Validate validates the authority configuration. -func (c *AuthConfig) Validate() error { +func (c *AuthConfig) Validate(audiences []string) error { var err error - if c == nil { return errors.New("authority cannot be undefined") } @@ -70,11 +71,18 @@ func (c *AuthConfig) Validate() error { if c.Claims, err = c.Claims.Init(&globalProvisionerClaims); err != nil { return err } + + // Initialize provisioners + config := provisioner.Config{ + Claims: *c.Claims, + Audiences: audiences, + } for _, p := range c.Provisioners { - if err := p.Init(c.Claims); err != nil { + if err := p.Init(config); err != nil { return err } } + if c.Template == nil { c.Template = &x509util.ASN1DN{} } @@ -153,5 +161,16 @@ func (c *Config) Validate() error { c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation } - return c.AuthorityConfig.Validate() + return c.AuthorityConfig.Validate(c.getAudiences()) +} + +// getAudiences returns the legacy and possible urls without the ports that will +// be used as the default provisioner audiences. The CA might have proxies in +// front so we cannot rely on the port. +func (c *Config) getAudiences() []string { + audiences := []string{legacyAuthority} + for _, name := range c.DNSNames { + audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name)) + } + return audiences } diff --git a/authority/config_test.go b/authority/config_test.go index 01cea2a1..ca5de829 100644 --- a/authority/config_test.go +++ b/authority/config_test.go @@ -5,6 +5,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/x509util" stepJOSE "github.com/smallstep/cli/jose" @@ -17,13 +18,13 @@ func TestConfigValidate(t *testing.T) { clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) ac := &AuthConfig{ - Provisioners: []*Provisioner{ - { + Provisioners: provisioner.List{ + &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, }, - { + &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, @@ -229,13 +230,13 @@ func TestAuthConfigValidate(t *testing.T) { assert.FatalError(t, err) clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) - p := []*Provisioner{ - { + p := provisioner.List{ + &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, }, - { + &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, @@ -263,9 +264,9 @@ func TestAuthConfigValidate(t *testing.T) { "fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest { return AuthConfigValidateTest{ ac: &AuthConfig{ - Provisioners: []*Provisioner{ - {Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, - {Name: "foo", Key: &jose.JSONWebKey{}}, + Provisioners: provisioner.List{ + &provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, + &provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}}, }, }, err: errors.New("provisioner type cannot be empty"), @@ -293,7 +294,7 @@ func TestAuthConfigValidate(t *testing.T) { for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) - err := tc.ac.Validate() + err := tc.ac.Validate([]string{}) if err != nil { if assert.NotNil(t, tc.err) { assert.Equals(t, tc.err.Error(), err.Error()) diff --git a/authority/provisioner.go b/authority/provisioner/claims.go similarity index 55% rename from authority/provisioner.go rename to authority/provisioner/claims.go index 6dd1b1ac..2fc68397 100644 --- a/authority/provisioner.go +++ b/authority/provisioner/claims.go @@ -1,17 +1,14 @@ -package authority +package provisioner import ( "time" "github.com/pkg/errors" - "github.com/smallstep/cli/crypto/x509util" - - jose "gopkg.in/square/go-jose.v2" ) -// ProvisionerClaims so that individual provisioners can override global claims. -type ProvisionerClaims struct { - globalClaims *ProvisionerClaims +// Claims so that individual provisioners can override global claims. +type Claims struct { + globalClaims *Claims MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` @@ -19,19 +16,18 @@ type ProvisionerClaims struct { } // Init initializes and validates the individual provisioner claims. -func (pc *ProvisionerClaims) Init(global *ProvisionerClaims) (*ProvisionerClaims, error) { +func (pc *Claims) Init(global *Claims) (*Claims, error) { if pc == nil { - pc = &ProvisionerClaims{} + pc = &Claims{} } pc.globalClaims = global - err := pc.Validate() - return pc, err + return pc, pc.Validate() } // DefaultTLSCertDuration returns the default TLS cert duration for the // provisioner. If the default is not set within the provisioner, then the global // default from the authority configuration will be used. -func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration { +func (pc *Claims) DefaultTLSCertDuration() time.Duration { if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 { return pc.globalClaims.DefaultTLSCertDuration() } @@ -41,7 +37,7 @@ func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration { // MinTLSCertDuration returns the minimum TLS cert duration for the provisioner. // If the minimum is not set within the provisioner, then the global // minimum from the authority configuration will be used. -func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration { +func (pc *Claims) MinTLSCertDuration() time.Duration { if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 { return pc.globalClaims.MinTLSCertDuration() } @@ -51,7 +47,7 @@ func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration { // MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner. // If the maximum is not set within the provisioner, then the global // maximum from the authority configuration will be used. -func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration { +func (pc *Claims) MaxTLSCertDuration() time.Duration { if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 { return pc.globalClaims.MaxTLSCertDuration() } @@ -61,7 +57,7 @@ func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration { // IsDisableRenewal returns if the renewal flow is disabled for the // provisioner. If the property is not set within the provisioner, then the // global value from the authority configuration will be used. -func (pc *ProvisionerClaims) IsDisableRenewal() bool { +func (pc *Claims) IsDisableRenewal() bool { if pc.DisableRenewal == nil { return pc.globalClaims.IsDisableRenewal() } @@ -69,7 +65,7 @@ func (pc *ProvisionerClaims) IsDisableRenewal() bool { } // Validate validates and modifies the Claims with default values. -func (pc *ProvisionerClaims) Validate() error { +func (pc *Claims) Validate() error { var ( min = pc.MinTLSCertDuration() max = pc.MaxTLSCertDuration() @@ -93,52 +89,3 @@ func (pc *ProvisionerClaims) Validate() error { return nil } } - -// Provisioner - authorized entity that can sign tokens necessary for signature requests. -type Provisioner struct { - Name string `json:"name,omitempty"` - Type string `json:"type,omitempty"` - Key *jose.JSONWebKey `json:"key,omitempty"` - EncryptedKey string `json:"encryptedKey,omitempty"` - Claims *ProvisionerClaims `json:"claims,omitempty"` -} - -// Init initializes and validates a the fields of Provisioner type. -func (p *Provisioner) Init(global *ProvisionerClaims) error { - switch { - case p.Name == "": - return errors.New("provisioner name cannot be empty") - - case p.Type == "": - return errors.New("provisioner type cannot be empty") - - case p.Key == nil: - return errors.New("provisioner key cannot be empty") - } - - var err error - p.Claims, err = p.Claims.Init(global) - return err -} - -// getTLSApps returns a list of modifiers and validators that will be applied to -// the certificate. -func (p *Provisioner) getTLSApps(so SignOptions) ([]x509util.WithOption, []certClaim, error) { - c := p.Claims - return []x509util.WithOption{ - x509util.WithNotBeforeAfterDuration(so.NotBefore, - so.NotAfter, c.DefaultTLSCertDuration()), - withProvisionerOID(p.Name, p.Key.KeyID), - }, []certClaim{ - &certTemporalClaim{ - min: c.MinTLSCertDuration(), - max: c.MaxTLSCertDuration(), - }, - }, nil -} - -// ID returns the provisioner identifier. The name and credential id should -// uniquely identify any provisioner. -func (p *Provisioner) ID() string { - return p.Name + ":" + p.Key.KeyID -} diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go new file mode 100644 index 00000000..8edd1965 --- /dev/null +++ b/authority/provisioner/collection.go @@ -0,0 +1,212 @@ +package provisioner + +import ( + "crypto/sha1" + "crypto/x509" + "encoding/asn1" + "encoding/binary" + "encoding/hex" + "fmt" + "net/url" + "sort" + "strings" + "sync" + + "github.com/pkg/errors" + "github.com/smallstep/cli/jose" +) + +// DefaultProvisionersLimit is the default limit for listing provisioners. +const DefaultProvisionersLimit = 20 + +// DefaultProvisionersMax is the maximum limit for listing provisioners. +const DefaultProvisionersMax = 100 + +type uidProvisioner struct { + provisioner Interface + uid string +} + +type provisionerSlice []uidProvisioner + +func (p provisionerSlice) Len() int { return len(p) } +func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } +func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// Collection is a memory map of provisioners. +type Collection struct { + byID *sync.Map + byKey *sync.Map + sorted provisionerSlice + audiences []string +} + +// NewCollection initializes a collection of provisioners. The given list of +// audiences are the audiences used by the JWT provisioner. +func NewCollection(audiences []string) *Collection { + return &Collection{ + byID: new(sync.Map), + byKey: new(sync.Map), + audiences: audiences, + } +} + +// Load a provisioner by the ID. +func (c *Collection) Load(id string) (Interface, bool) { + return loadProvisioner(c.byID, id) +} + +// LoadByToken parses the token claims and loads the provisioner associated. +func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) { + // match with server audiences + if matchesAudience(claims.Audience, c.audiences) { + // If matches with stored audiences it will be a JWT token (default), and + // the id would be :. + return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID) + } + + // The ID will be just the clientID stored in azp or aud. + var payload openIDPayload + if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil { + return nil, false + } + // audience is required + if len(payload.Audience) == 0 { + return nil, false + } + if len(payload.AuthorizedParty) > 0 { + return c.Load(payload.AuthorizedParty) + } + return c.Load(payload.Audience[0]) +} + +// LoadByCertificate looks for the provisioner extension and extracts the +// proper id to load the provisioner. +func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) { + for _, e := range cert.Extensions { + if e.Id.Equal(stepOIDProvisioner) { + var provisioner stepProvisionerASN1 + if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { + return nil, false + } + if provisioner.Type == int(TypeJWK) { + return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID)) + } + return c.Load(string(provisioner.CredentialID)) + } + } + + // Default to noop provisioner if an extension is not found. This allows to + // accept a renewal of a cert without the provisioner extension. + return &noop{}, true +} + +// LoadEncryptedKey returns an encrypted key by indexed by KeyID. At this moment +// only JWK encrypted keys are indexed by KeyID. +func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) { + p, ok := loadProvisioner(c.byKey, keyID) + if !ok { + return "", false + } + _, key, ok := p.GetEncryptedKey() + return key, ok +} + +// Store adds a provisioner to the collection and enforces the uniqueness of +// provisioner IDs. +func (c *Collection) Store(p Interface) error { + // Store provisioner always in byID. ID must be unique. + if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true { + return errors.New("cannot add multiple provisioners with the same id") + } + + // Store provisioner in byKey if EncryptedKey is defined. + if kid, _, ok := p.GetEncryptedKey(); ok { + c.byKey.Store(kid, p) + } + + // Store sorted provisioners. + // Use the first 4 bytes (32bit) of the sum to insert the order + // Using big endian format to get the strings sorted: + // 0x00000000, 0x00000001, 0x00000002, ... + bi := make([]byte, 4) + sum := provisionerSum(p) + binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len())) + sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3] + c.sorted = append(c.sorted, uidProvisioner{ + provisioner: p, + uid: hex.EncodeToString(sum), + }) + sort.Sort(c.sorted) + return nil +} + +// Find implements pagination on a list of sorted provisioners. +func (c *Collection) Find(cursor string, limit int) (List, string) { + switch { + case limit <= 0: + limit = DefaultProvisionersLimit + case limit > DefaultProvisionersMax: + limit = DefaultProvisionersMax + } + + n := c.sorted.Len() + cursor = fmt.Sprintf("%040s", cursor) + i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor }) + + slice := List{} + for ; i < n && len(slice) < limit; i++ { + slice = append(slice, c.sorted[i].provisioner) + } + + if i < n { + return slice, strings.TrimLeft(c.sorted[i].uid, "0") + } + return slice, "" +} + +func loadProvisioner(m *sync.Map, key string) (Interface, bool) { + i, ok := m.Load(key) + if !ok { + return nil, false + } + p, ok := i.(Interface) + if !ok { + return nil, false + } + return p, true +} + +// provisionerSum returns the SHA1 of the provisioners ID. From this we will +// create the unique and sorted id. +func provisionerSum(p Interface) []byte { + sum := sha1.Sum([]byte(p.GetID())) + return sum[:] +} + +// matchesAudience returns true if A and B share at least one element. +func matchesAudience(as, bs []string) bool { + if len(bs) == 0 || len(as) == 0 { + return false + } + + for _, b := range bs { + for _, a := range as { + if b == a || stripPort(a) == stripPort(b) { + return true + } + } + } + return false +} + +// stripPort attempts to strip the port from the given url. If parsing the url +// produces errors it will just return the passed argument. +func stripPort(rawurl string) string { + u, err := url.Parse(rawurl) + if err != nil { + return rawurl + } + u.Host = u.Hostname() + return u.String() +} diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go new file mode 100644 index 00000000..d065d5f3 --- /dev/null +++ b/authority/provisioner/collection_test.go @@ -0,0 +1,390 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "reflect" + "strings" + "sync" + "testing" + + "github.com/smallstep/assert" + "github.com/smallstep/cli/jose" +) + +func TestCollection_Load(t *testing.T) { + p, err := generateJWK() + assert.FatalError(t, err) + byID := new(sync.Map) + byID.Store(p.GetID(), p) + byID.Store("string", "a-string") + + type fields struct { + byID *sync.Map + } + type args struct { + id string + } + tests := []struct { + name string + fields fields + args args + want Interface + want1 bool + }{ + {"ok", fields{byID}, args{p.GetID()}, p, true}, + {"fail", fields{byID}, args{"fail"}, nil, false}, + {"invalid", fields{byID}, args{"string"}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Collection{ + byID: tt.fields.byID, + } + got, got1 := c.Load(tt.args.id) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.Load() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestCollection_LoadByToken(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateJWK() + assert.FatalError(t, err) + p3, err := generateOIDC() + assert.FatalError(t, err) + + byID := new(sync.Map) + byID.Store(p1.GetID(), p1) + byID.Store(p2.GetID(), p2) + byID.Store(p3.GetID(), p3) + byID.Store("string", "a-string") + + jwk, err := decryptJSONWebKey(p1.EncryptedKey) + assert.FatalError(t, err) + token, err := generateSimpleToken(p1.Name, testAudiences[0], jwk) + assert.FatalError(t, err) + t1, c1, err := parseToken(token) + assert.FatalError(t, err) + + jwk, err = decryptJSONWebKey(p2.EncryptedKey) + assert.FatalError(t, err) + token, err = generateSimpleToken(p2.Name, testAudiences[1], jwk) + assert.FatalError(t, err) + t2, c2, err := parseToken(token) + assert.FatalError(t, err) + + token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + t3, c3, err := parseToken(token) + assert.FatalError(t, err) + + token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + t4, c4, err := parseToken(token) + assert.FatalError(t, err) + + type fields struct { + byID *sync.Map + audiences []string + } + type args struct { + token *jose.JSONWebToken + claims *jose.Claims + } + tests := []struct { + name string + fields fields + args args + want Interface + want1 bool + }{ + {"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true}, + {"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true}, + {"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true}, + {"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false}, + {"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Collection{ + byID: tt.fields.byID, + audiences: tt.fields.audiences, + } + got, got1 := c.LoadByToken(tt.args.token, tt.args.claims) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestCollection_LoadByCertificate(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + + byID := new(sync.Map) + byID.Store(p1.GetID(), p1) + byID.Store(p2.GetID(), p2) + + ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID) + assert.FatalError(t, err) + ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID) + assert.FatalError(t, err) + notFoundExt, err := createProvisionerExtension(1, "foo", "bar") + assert.FatalError(t, err) + + ok1Cert := &x509.Certificate{ + Extensions: []pkix.Extension{ok1Ext}, + } + ok2Cert := &x509.Certificate{ + Extensions: []pkix.Extension{ok2Ext}, + } + notFoundCert := &x509.Certificate{ + Extensions: []pkix.Extension{notFoundExt}, + } + badCert := &x509.Certificate{ + Extensions: []pkix.Extension{ + {Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")}, + }, + } + + type fields struct { + byID *sync.Map + audiences []string + } + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + want Interface + want1 bool + }{ + {"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true}, + {"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true}, + {"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true}, + {"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false}, + {"badCert", fields{byID, testAudiences}, args{badCert}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Collection{ + byID: tt.fields.byID, + audiences: tt.fields.audiences, + } + got, got1 := c.LoadByCertificate(tt.args.cert) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestCollection_LoadEncryptedKey(t *testing.T) { + c := NewCollection(testAudiences) + p1, err := generateJWK() + assert.FatalError(t, err) + assert.FatalError(t, c.Store(p1)) + p2, err := generateOIDC() + assert.FatalError(t, err) + assert.FatalError(t, c.Store(p2)) + + // Add oidc in byKey. + // It should not happen. + p2KeyID := p2.keyStore.keySet.Keys[0].KeyID + c.byKey.Store(p2KeyID, p2) + + type args struct { + keyID string + } + tests := []struct { + name string + args args + want string + want1 bool + }{ + {"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true}, + {"oidc", args{p2KeyID}, "", false}, + {"notFound", args{"not-found"}, "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := c.LoadEncryptedKey(tt.args.keyID) + if got != tt.want { + t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestCollection_Store(t *testing.T) { + c := NewCollection(testAudiences) + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + + type args struct { + p Interface + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok1", args{p1}, false}, + {"ok2", args{p2}, false}, + {"fail1", args{p1}, true}, + {"fail2", args{p2}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := c.Store(tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCollection_Find(t *testing.T) { + c, err := generateCollection(10, 10) + assert.FatalError(t, err) + + trim := func(s string) string { + return strings.TrimLeft(s, "0") + } + toList := func(ps provisionerSlice) List { + l := List{} + for _, p := range ps { + l = append(l, p.provisioner) + } + return l + } + + type args struct { + cursor string + limit int + } + tests := []struct { + name string + args args + want List + want1 string + }{ + {"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""}, + {"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""}, + {"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)}, + {"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""}, + {"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)}, + {"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)}, + {"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""}, + {"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := c.Find(tt.args.cursor, tt.args.limit) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.Find() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func Test_matchesAudience(t *testing.T) { + type matchesTest struct { + a, b []string + exp bool + } + tests := map[string]matchesTest{ + "false arg1 empty": { + a: []string{}, + b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, + exp: false, + }, + "false arg2 empty": { + a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, + b: []string{}, + exp: false, + }, + "false arg1,arg2 empty": { + a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, + b: []string{"step-gateway", "step-cli"}, + exp: false, + }, + "false": { + a: []string{"step-gateway", "step-cli"}, + b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, + exp: false, + }, + "true": { + a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"}, + b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, + exp: true, + }, + "true,portsA": { + a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"}, + b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, + exp: true, + }, + "true,portsB": { + a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"}, + b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"}, + exp: true, + }, + "true,portsAB": { + a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"}, + b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"}, + exp: true, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b)) + }) + } +} + +func Test_stripPort(t *testing.T) { + type args struct { + rawurl string + } + tests := []struct { + name string + args args + want string + }{ + {"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"}, + {"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"}, + {"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := stripPort(tt.args.rawurl); got != tt.want { + t.Errorf("stripPort() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/duration.go b/authority/provisioner/duration.go new file mode 100644 index 00000000..38d504a3 --- /dev/null +++ b/authority/provisioner/duration.go @@ -0,0 +1,45 @@ +package provisioner + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal. +type Duration struct { + time.Duration +} + +// MarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.Duration.String()) +} + +// UnmarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *Duration) UnmarshalJSON(data []byte) (err error) { + var ( + s string + _d time.Duration + ) + if d == nil { + return errors.New("duration cannot be nil") + } + if err = json.Unmarshal(data, &s); err != nil { + return errors.Wrapf(err, "error unmarshaling %s", data) + } + if _d, err = time.ParseDuration(s); err != nil { + return errors.Wrapf(err, "error parsing %s as duration", s) + } + d.Duration = _d + return +} diff --git a/authority/provisioner/duration_test.go b/authority/provisioner/duration_test.go new file mode 100644 index 00000000..4f7304a0 --- /dev/null +++ b/authority/provisioner/duration_test.go @@ -0,0 +1,61 @@ +package provisioner + +import ( + "reflect" + "testing" + "time" +) + +func TestDuration_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + d *Duration + args args + want *Duration + wantErr bool + }{ + {"empty", new(Duration), args{[]byte{}}, new(Duration), true}, + {"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true}, + {"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true}, + {"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true}, + {"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false}, + {"nil", nil, args{nil}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(tt.d, tt.want) { + t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want) + } + }) + } +} + +func TestDuration_MarshalJSON(t *testing.T) { + tests := []struct { + name string + d *Duration + want []byte + wantErr bool + }{ + {"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.d.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go new file mode 100644 index 00000000..4575ca47 --- /dev/null +++ b/authority/provisioner/jwk.go @@ -0,0 +1,125 @@ +package provisioner + +import ( + "crypto/x509" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/cli/crypto/x509util" + "github.com/smallstep/cli/jose" +) + +// jwtPayload extends jwt.Claims with step attributes. +type jwtPayload struct { + jose.Claims + SANs []string `json:"sans,omitempty"` +} + +// JWK is the default provisioner, an entity that can sign tokens necessary for +// signature requests. +type JWK struct { + Type string `json:"type"` + Name string `json:"name"` + Key *jose.JSONWebKey `json:"key"` + EncryptedKey string `json:"encryptedKey,omitempty"` + Claims *Claims `json:"claims,omitempty"` + audiences []string +} + +// GetID returns the provisioner unique identifier. The name and credential id +// should uniquely identify any JWK provisioner. +func (p *JWK) GetID() string { + return p.Name + ":" + p.Key.KeyID +} + +// GetName returns the name of the provisioner. +func (p *JWK) GetName() string { + return p.Name +} + +// GetType returns the type of provisioner. +func (p *JWK) GetType() Type { + return TypeJWK +} + +// GetEncryptedKey returns the base provisioner encrypted key if it's defined. +func (p *JWK) GetEncryptedKey() (string, string, bool) { + return p.Key.KeyID, p.EncryptedKey, len(p.EncryptedKey) > 0 +} + +// Init initializes and validates the fields of a JWK type. +func (p *JWK) Init(config Config) (err error) { + switch { + case p.Type == "": + return errors.New("provisioner type cannot be empty") + case p.Name == "": + return errors.New("provisioner name cannot be empty") + case p.Key == nil: + return errors.New("provisioner key cannot be empty") + } + p.Claims, err = p.Claims.Init(&config.Claims) + p.audiences = config.Audiences + return err +} + +// Authorize validates the given token. +func (p *JWK) Authorize(token string) ([]SignOption, error) { + jwt, err := jose.ParseSigned(token) + if err != nil { + return nil, errors.Wrapf(err, "error parsing token") + } + + var claims jwtPayload + if err = jwt.Claims(p.Key, &claims); err != nil { + return nil, errors.Wrap(err, "error parsing claims") + } + + // According to "rfc7519 JSON Web Token" acceptable skew should be no + // more than a few minutes. + if err = claims.ValidateWithLeeway(jose.Expected{ + Issuer: p.Name, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + return nil, errors.Wrapf(err, "invalid token") + } + + // validate audiences with the defaults + if !matchesAudience(claims.Audience, p.audiences) { + return nil, errors.New("invalid token: invalid audience claim (aud)") + } + + if claims.Subject == "" { + return nil, errors.New("token subject cannot be empty") + } + + // NOTE: This is for backwards compatibility with older versions of cli + // and certificates. Older versions added the token subject as the only SAN + // in a CSR by default. + if len(claims.SANs) == 0 { + claims.SANs = []string{claims.Subject} + } + + dnsNames, ips := x509util.SplitSANs(claims.SANs) + return []SignOption{ + commonNameValidator(claims.Subject), + dnsNamesValidator(dnsNames), + ipAddressesValidator(ips), + profileDefaultDuration(p.Claims.DefaultTLSCertDuration()), + newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), + newValidityValidator(p.Claims.MinTLSCertDuration(), p.Claims.MaxTLSCertDuration()), + }, nil +} + +// AuthorizeRenewal returns an error if the renewal is disabled. +func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error { + if p.Claims.IsDisableRenewal() { + return errors.Errorf("renew is disabled for provisioner %s", p.GetID()) + } + return nil +} + +// AuthorizeRevoke returns an error if the provisioner does not have rights to +// revoke the certificate with serial number in the `sub` property. +func (p *JWK) AuthorizeRevoke(token string) error { + return errors.New("not implemented") +} diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go new file mode 100644 index 00000000..53260313 --- /dev/null +++ b/authority/provisioner/jwk_test.go @@ -0,0 +1,256 @@ +package provisioner + +import ( + "crypto/x509" + "errors" + "strings" + "testing" + "time" + + "github.com/smallstep/assert" + "github.com/smallstep/cli/jose" +) + +var ( + defaultDisableRenewal = false + globalProvisionerClaims = Claims{ + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + } +) + +func TestJWK_Getters(t *testing.T) { + p, err := generateJWK() + assert.FatalError(t, err) + if got := p.GetID(); got != p.Name+":"+p.Key.KeyID { + t.Errorf("JWK.GetID() = %v, want %v:%v", got, p.Name, p.Key.KeyID) + } + if got := p.GetName(); got != p.Name { + t.Errorf("JWK.GetName() = %v, want %v", got, p.Name) + } + if got := p.GetType(); got != TypeJWK { + t.Errorf("JWK.GetType() = %v, want %v", got, TypeJWK) + } + kid, key, ok := p.GetEncryptedKey() + if kid != p.Key.KeyID || key != p.EncryptedKey || ok == false { + t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, p.Key.KeyID, p.EncryptedKey, true) + } + p.EncryptedKey = "" + kid, key, ok = p.GetEncryptedKey() + if kid != p.Key.KeyID || key != "" || ok == true { + t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, p.Key.KeyID, "", false) + } +} + +func TestJWK_Init(t *testing.T) { + type ProvisionerValidateTest struct { + p *JWK + err error + } + tests := map[string]func(*testing.T) ProvisionerValidateTest{ + "fail-empty": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &JWK{}, + err: errors.New("provisioner type cannot be empty"), + } + }, + "fail-empty-name": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &JWK{ + Type: "JWK", + }, + err: errors.New("provisioner name cannot be empty"), + } + }, + "fail-empty-type": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &JWK{Name: "foo"}, + err: errors.New("provisioner type cannot be empty"), + } + }, + "fail-empty-key": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &JWK{Name: "foo", Type: "bar"}, + err: errors.New("provisioner key cannot be empty"), + } + }, + "ok": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, + } + }, + } + + config := Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + } + for name, get := range tests { + t.Run(name, func(t *testing.T) { + tc := get(t) + err := tc.p.Init(config) + if err != nil { + if assert.NotNil(t, tc.err) { + assert.Equals(t, tc.err.Error(), err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestJWK_Authorize(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateJWK() + assert.FatalError(t, err) + + key1, err := decryptJSONWebKey(p1.EncryptedKey) + assert.FatalError(t, err) + key2, err := decryptJSONWebKey(p2.EncryptedKey) + assert.FatalError(t, err) + + t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1) + assert.FatalError(t, err) + t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2) + assert.FatalError(t, err) + t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], "", []string{}, time.Now(), key1) + assert.FatalError(t, err) + + // Invalid tokens + parts := strings.Split(t1, ".") + key3, err := generateJSONWebKey() + assert.FatalError(t, err) + // missing key + failKey, err := generateSimpleToken(p1.Name, testAudiences[0], key3) + assert.FatalError(t, err) + // invalid token + failTok := "foo." + parts[1] + "." + parts[2] + // invalid claims + failClaims := parts[0] + ".foo." + parts[1] + // invalid issuer + failIss, err := generateSimpleToken("foobar", testAudiences[0], key1) + assert.FatalError(t, err) + // invalid audience + failAud, err := generateSimpleToken(p1.Name, "foobar", key1) + assert.FatalError(t, err) + // invalid signature + failSig := t1[0 : len(t1)-2] + // no subject + failSub, err := generateToken("", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now(), key1) + assert.FatalError(t, err) + // expired + failExp, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1) + assert.FatalError(t, err) + // not before + failNbf, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1) + assert.FatalError(t, err) + + // Remove encrypted key for p2 + p2.EncryptedKey = "" + + type args struct { + token string + } + tests := []struct { + name string + prov *JWK + args args + wantErr bool + }{ + {"ok", p1, args{t1}, false}, + {"ok-no-encrypted-key", p2, args{t2}, false}, + {"ok-no-sans", p1, args{t3}, false}, + {"fail-key", p1, args{failKey}, true}, + {"fail-token", p1, args{failTok}, true}, + {"fail-claims", p1, args{failClaims}, true}, + {"fail-issuer", p1, args{failIss}, true}, + {"fail-audience", p1, args{failAud}, true}, + {"fail-signature", p1, args{failSig}, true}, + {"fail-subject", p1, args{failSub}, true}, + {"fail-expired", p1, args{failExp}, true}, + {"fail-not-before", p1, args{failNbf}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.prov.Authorize(tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("JWK.Authorize() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.Nil(t, got) + } else { + assert.NotNil(t, got) + assert.Len(t, 6, got) + } + }) + } +} + +func TestJWK_AuthorizeRenewal(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateJWK() + assert.FatalError(t, err) + + // disable renewal + disable := true + p2.Claims = &Claims{ + globalClaims: &globalProvisionerClaims, + DisableRenewal: &disable, + } + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + prov *JWK + args args + wantErr bool + }{ + {"ok", p1, args{nil}, false}, + {"fail", p2, args{nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("JWK.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestJWK_AuthorizeRevoke(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + key1, err := decryptJSONWebKey(p1.EncryptedKey) + assert.FatalError(t, err) + t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1) + assert.FatalError(t, err) + + type args struct { + token string + } + tests := []struct { + name string + prov *JWK + args args + wantErr bool + }{ + {"disabled", p1, args{t1}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("JWK.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go new file mode 100644 index 00000000..2f11114a --- /dev/null +++ b/authority/provisioner/keystore.go @@ -0,0 +1,135 @@ +package provisioner + +import ( + "encoding/json" + "math/rand" + "net/http" + "regexp" + "strconv" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/cli/jose" +) + +const ( + defaultCacheAge = 12 * time.Hour + defaultCacheJitter = 1 * time.Hour +) + +var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)") + +type keyStore struct { + sync.RWMutex + uri string + keySet jose.JSONWebKeySet + timer *time.Timer + expiry time.Time + jitter time.Duration +} + +func newKeyStore(uri string) (*keyStore, error) { + keys, age, err := getKeysFromJWKsURI(uri) + if err != nil { + return nil, err + } + ks := &keyStore{ + uri: uri, + keySet: keys, + expiry: getExpirationTime(age), + jitter: getCacheJitter(age), + } + next := ks.nextReloadDuration(age) + ks.timer = time.AfterFunc(next, ks.reload) + return ks, nil +} + +func (ks *keyStore) Close() { + ks.timer.Stop() +} + +func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) { + ks.RLock() + // Force reload if expiration has passed + if time.Now().After(ks.expiry) { + ks.RUnlock() + ks.reload() + ks.RLock() + } + keys = ks.keySet.Key(kid) + ks.RUnlock() + return +} + +func (ks *keyStore) reload() { + var next time.Duration + keys, age, err := getKeysFromJWKsURI(ks.uri) + if err != nil { + next = ks.nextReloadDuration(ks.jitter / 2) + } else { + ks.Lock() + ks.keySet = keys + ks.expiry = getExpirationTime(age) + ks.jitter = getCacheJitter(age) + next = ks.nextReloadDuration(age) + ks.Unlock() + } + + ks.Lock() + ks.timer.Reset(next) + ks.Unlock() +} + +func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { + n := rand.Int63n(int64(ks.jitter)) + age -= time.Duration(n) + if age < 0 { + age = 0 + } + return age +} + +func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) { + var keys jose.JSONWebKeySet + resp, err := http.Get(uri) + if err != nil { + return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri) + } + defer resp.Body.Close() + if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { + return keys, 0, errors.Wrapf(err, "error reading %s", uri) + } + return keys, getCacheAge(resp.Header.Get("cache-control")), nil +} + +func getCacheAge(cacheControl string) time.Duration { + age := defaultCacheAge + if len(cacheControl) > 0 { + match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1) + if len(match) > 0 { + if len(match[0]) == 2 { + maxAge := match[0][1] + maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64) + if err != nil { + return defaultCacheAge + } + age = time.Duration(maxAgeInt) * time.Second + } + } + } + return age +} + +func getCacheJitter(age time.Duration) time.Duration { + switch { + case age > time.Hour: + return defaultCacheJitter + default: + return age / 3 + } +} + +func getExpirationTime(age time.Duration) time.Time { + return time.Now().Truncate(time.Second).Add(age) +} diff --git a/authority/provisioner/keystore_test.go b/authority/provisioner/keystore_test.go new file mode 100644 index 00000000..22d5be75 --- /dev/null +++ b/authority/provisioner/keystore_test.go @@ -0,0 +1,121 @@ +package provisioner + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + "github.com/smallstep/assert" + "github.com/smallstep/cli/jose" +) + +func Test_newKeyStore(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + ks, err := newKeyStore(srv.URL) + assert.FatalError(t, err) + defer ks.Close() + + type args struct { + uri string + } + tests := []struct { + name string + args args + want jose.JSONWebKeySet + wantErr bool + }{ + {"ok", args{srv.URL}, ks.keySet, false}, + {"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newKeyStore(tt.args.uri) + if (err != nil) != tt.wantErr { + t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + if !reflect.DeepEqual(got.keySet, tt.want) { + t.Errorf("newKeyStore() = %v, want %v", got, tt.want) + } + got.Close() + } + }) + } +} + +func Test_keyStore(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + + ks, err := newKeyStore(srv.URL + "/random") + assert.FatalError(t, err) + defer ks.Close() + ks.RLock() + keySet1 := ks.keySet + ks.RUnlock() + // Check contents + assert.Len(t, 2, keySet1.Keys) + assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID)) + assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID)) + assert.Len(t, 0, ks.Get("foobar")) + + // Wait for rotation + time.Sleep(5 * time.Second) + + ks.RLock() + keySet2 := ks.keySet + ks.RUnlock() + if reflect.DeepEqual(keySet1, keySet2) { + t.Error("keyStore did not rotated") + } + + // Check contents + assert.Len(t, 2, keySet2.Keys) + assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID)) + assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID)) + assert.Len(t, 0, ks.Get("foobar")) + + // Check hits + resp, err := srv.Client().Get(srv.URL + "/hits") + assert.FatalError(t, err) + hits := struct { + Hits int `json:"hits"` + }{} + defer resp.Body.Close() + err = json.NewDecoder(resp.Body).Decode(&hits) + assert.FatalError(t, err) + assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits)) +} + +func Test_keyStore_Get(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + ks, err := newKeyStore(srv.URL) + assert.FatalError(t, err) + defer ks.Close() + + type args struct { + kid string + } + tests := []struct { + name string + ks *keyStore + args args + wantKeys []jose.JSONWebKey + }{ + {"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}}, + {"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}}, + {"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) { + t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys) + } + }) + } +} diff --git a/authority/provisioner/noop.go b/authority/provisioner/noop.go new file mode 100644 index 00000000..c00ba61f --- /dev/null +++ b/authority/provisioner/noop.go @@ -0,0 +1,37 @@ +package provisioner + +import "crypto/x509" + +// noop provisioners is a provisioner that accepts anything. +type noop struct{} + +func (p *noop) GetID() string { + return "noop" +} + +func (p *noop) GetName() string { + return "noop" +} +func (p *noop) GetType() Type { + return noopType +} + +func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) { + return "", "", false +} + +func (p *noop) Init(config Config) error { + return nil +} + +func (p *noop) Authorize(token string) ([]SignOption, error) { + return []SignOption{}, nil +} + +func (p *noop) AuthorizeRenewal(cert *x509.Certificate) error { + return nil +} + +func (p *noop) AuthorizeRevoke(token string) error { + return nil +} diff --git a/authority/provisioner/noop_test.go b/authority/provisioner/noop_test.go new file mode 100644 index 00000000..a548430e --- /dev/null +++ b/authority/provisioner/noop_test.go @@ -0,0 +1,27 @@ +package provisioner + +import ( + "crypto/x509" + "testing" + + "github.com/smallstep/assert" +) + +func Test_noop(t *testing.T) { + p := noop{} + assert.Equals(t, "noop", p.GetID()) + assert.Equals(t, "noop", p.GetName()) + assert.Equals(t, noopType, p.GetType()) + assert.Equals(t, nil, p.Init(Config{})) + assert.Equals(t, nil, p.AuthorizeRenewal(&x509.Certificate{})) + assert.Equals(t, nil, p.AuthorizeRevoke("foo")) + + kid, key, ok := p.GetEncryptedKey() + assert.Equals(t, "", kid) + assert.Equals(t, "", key) + assert.Equals(t, false, ok) + + sigOptions, err := p.Authorize("foo") + assert.Equals(t, []SignOption{}, sigOptions) + assert.Equals(t, nil, err) +} diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go new file mode 100644 index 00000000..831a6e8b --- /dev/null +++ b/authority/provisioner/oidc.go @@ -0,0 +1,243 @@ +package provisioner + +import ( + "crypto/x509" + "encoding/json" + "net/http" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/cli/jose" +) + +// openIDConfiguration contains the necessary properties in the +// `/.well-known/openid-configuration` document. +type openIDConfiguration struct { + Issuer string `json:"issuer"` + JWKSetURI string `json:"jwks_uri"` +} + +// Validate validates the values in a well-known OpenID configuration endpoint. +func (c openIDConfiguration) Validate() error { + switch { + case c.Issuer == "": + return errors.New("issuer cannot be empty") + case c.JWKSetURI == "": + return errors.New("jwks_uri cannot be empty") + default: + return nil + } +} + +// openIDPayload represents the fields on the id_token JWT payload. +type openIDPayload struct { + jose.Claims + AtHash string `json:"at_hash"` + AuthorizedParty string `json:"azp"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Hd string `json:"hd"` + Nonce string `json:"nonce"` +} + +// OIDC represents an OAuth 2.0 OpenID Connect provider. +// +// ClientSecret is mandatory, but it can be an empty string. +type OIDC struct { + Type string `json:"type"` + Name string `json:"name"` + ClientID string `json:"clientID"` + ClientSecret string `json:"clientSecret"` + ConfigurationEndpoint string `json:"configurationEndpoint"` + Admins []string `json:"admins,omitempty"` + Domains []string `json:"domains,omitempty"` + Claims *Claims `json:"claims,omitempty"` + configuration openIDConfiguration + keyStore *keyStore +} + +// IsAdmin returns true if the given email is in the Admins whitelist, false +// otherwise. +func (o *OIDC) IsAdmin(email string) bool { + email = sanitizeEmail(email) + for _, e := range o.Admins { + if email == sanitizeEmail(e) { + return true + } + } + return false +} + +func sanitizeEmail(email string) string { + if i := strings.LastIndex(email, "@"); i >= 0 { + email = email[:i] + strings.ToLower(email[i:]) + } + return email +} + +// GetID returns the provisioner unique identifier, the OIDC provisioner the +// uses the clientID for this. +func (o *OIDC) GetID() string { + return o.ClientID +} + +// GetName returns the name of the provisioner. +func (o *OIDC) GetName() string { + return o.Name +} + +// GetType returns the type of provisioner. +func (o *OIDC) GetType() Type { + return TypeOIDC +} + +// GetEncryptedKey is not available in an OIDC provisioner. +func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) { + return "", "", false +} + +// Init validates and initializes the OIDC provider. +func (o *OIDC) Init(config Config) (err error) { + switch { + case o.Type == "": + return errors.New("type cannot be empty") + case o.Name == "": + return errors.New("name cannot be empty") + case o.ClientID == "": + return errors.New("clientID cannot be empty") + case o.ConfigurationEndpoint == "": + return errors.New("configurationEndpoint cannot be empty") + } + + // Update claims with global ones + if o.Claims, err = o.Claims.Init(&config.Claims); err != nil { + return err + } + // Decode and validate openid-configuration endpoint + if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil { + return err + } + if err := o.configuration.Validate(); err != nil { + return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint) + } + // Get JWK key set + o.keyStore, err = newKeyStore(o.configuration.JWKSetURI) + if err != nil { + return err + } + return nil +} + +// ValidatePayload validates the given token payload. +func (o *OIDC) ValidatePayload(p openIDPayload) error { + // According to "rfc7519 JSON Web Token" acceptable skew should be no more + // than a few minutes. + if err := p.ValidateWithLeeway(jose.Expected{ + Issuer: o.configuration.Issuer, + Audience: jose.Audience{o.ClientID}, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + return errors.Wrap(err, "failed to validate payload") + } + + // Validate azp if present + if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID { + return errors.New("failed to validate payload: invalid azp") + } + + // Enforce an email claim + if p.Email == "" { + return errors.New("failed to validate payload: email not found") + } + + // Validate domains (case-insensitive) + if !o.IsAdmin(p.Email) && len(o.Domains) > 0 { + email := sanitizeEmail(p.Email) + var found bool + for _, d := range o.Domains { + if strings.HasSuffix(email, "@"+strings.ToLower(d)) { + found = true + break + } + } + if !found { + return errors.New("failed to validate payload: email is not allowed") + } + } + + return nil +} + +// Authorize validates the given token. +func (o *OIDC) Authorize(token string) ([]SignOption, error) { + jwt, err := jose.ParseSigned(token) + if err != nil { + return nil, errors.Wrapf(err, "error parsing token") + } + + // Parse claims to get the kid + var claims openIDPayload + if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, errors.Wrap(err, "error parsing claims") + } + + found := false + kid := jwt.Headers[0].KeyID + keys := o.keyStore.Get(kid) + for _, key := range keys { + if err := jwt.Claims(key, &claims); err == nil { + found = true + break + } + } + if !found { + return nil, errors.New("cannot validate token") + } + + if err := o.ValidatePayload(claims); err != nil { + return nil, err + } + + // Admins should be able to authorize any SAN + if o.IsAdmin(claims.Email) { + return []SignOption{ + profileDefaultDuration(o.Claims.DefaultTLSCertDuration()), + newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), + newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()), + }, nil + } + + return []SignOption{ + emailOnlyIdentity(claims.Email), + profileDefaultDuration(o.Claims.DefaultTLSCertDuration()), + newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), + newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()), + }, nil +} + +// AuthorizeRenewal returns an error if the renewal is disabled. +func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error { + if o.Claims.IsDisableRenewal() { + return errors.Errorf("renew is disabled for provisioner %s", o.GetID()) + } + return nil +} + +// AuthorizeRevoke returns an error if the provisioner does not have rights to +// revoke the certificate with serial number in the `sub` property. +func (o *OIDC) AuthorizeRevoke(token string) error { + return errors.New("not implemented") +} + +func getAndDecode(uri string, v interface{}) error { + resp, err := http.Get(uri) + if err != nil { + return errors.Wrapf(err, "failed to connect to %s", uri) + } + defer resp.Body.Close() + if err := json.NewDecoder(resp.Body).Decode(v); err != nil { + return errors.Wrapf(err, "error reading %s", uri) + } + return nil +} diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go new file mode 100644 index 00000000..eddf27fa --- /dev/null +++ b/authority/provisioner/oidc_test.go @@ -0,0 +1,327 @@ +package provisioner + +import ( + "crypto/x509" + "fmt" + "strings" + "testing" + "time" + + "github.com/smallstep/assert" + "github.com/smallstep/cli/jose" +) + +func Test_openIDConfiguration_Validate(t *testing.T) { + type fields struct { + Issuer string + JWKSetURI string + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok", fields{"the-issuer", "the-jwks-uri"}, false}, + {"no-issuer", fields{"", "the-jwks-uri"}, true}, + {"no-jwks-uri", fields{"the-issuer", ""}, true}, + {"empty", fields{"", ""}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := openIDConfiguration{ + Issuer: tt.fields.Issuer, + JWKSetURI: tt.fields.JWKSetURI, + } + if err := c.Validate(); (err != nil) != tt.wantErr { + t.Errorf("openIDConfiguration.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestOIDC_Getters(t *testing.T) { + p, err := generateOIDC() + assert.FatalError(t, err) + if got := p.GetID(); got != p.ClientID { + t.Errorf("OIDC.GetID() = %v, want %v", got, p.ClientID) + } + if got := p.GetName(); got != p.Name { + t.Errorf("OIDC.GetName() = %v, want %v", got, p.Name) + } + if got := p.GetType(); got != TypeOIDC { + t.Errorf("OIDC.GetType() = %v, want %v", got, TypeOIDC) + } + kid, key, ok := p.GetEncryptedKey() + if kid != "" || key != "" || ok == true { + t.Errorf("OIDC.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", + kid, key, ok, "", "", false) + } +} + +func TestOIDC_Init(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + config := Config{ + Claims: globalProvisionerClaims, + } + + type fields struct { + Type string + Name string + ClientID string + ClientSecret string + ConfigurationEndpoint string + Claims *Claims + Admins []string + Domains []string + } + type args struct { + config Config + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false}, + {"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}, nil}, args{config}, false}, + {"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, []string{"smallstep.com"}}, args{config}, false}, + {"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false}, + {"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true}, + {"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true}, + {"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true}, + {"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true}, + {"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &OIDC{ + Type: tt.fields.Type, + Name: tt.fields.Name, + ClientID: tt.fields.ClientID, + ConfigurationEndpoint: tt.fields.ConfigurationEndpoint, + Claims: tt.fields.Claims, + Admins: tt.fields.Admins, + } + if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { + t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr == false { + assert.Len(t, 2, p.keyStore.keySet.Keys) + assert.Equals(t, openIDConfiguration{ + Issuer: "the-issuer", + JWKSetURI: srv.URL + "/jwks_uri", + }, p.configuration) + } + }) + } +} + +func TestOIDC_Authorize(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + + var keys jose.JSONWebKeySet + assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys)) + + // Create test provisioners + p1, err := generateOIDC() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + p3, err := generateOIDC() + assert.FatalError(t, err) + // Admin + Domains + p3.Admins = []string{"name@smallstep.com", "root@example.com"} + p3.Domains = []string{"smallstep.com"} + + // Update configuration endpoints and initialize + config := Config{Claims: globalProvisionerClaims} + p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + assert.FatalError(t, p1.Init(config)) + assert.FatalError(t, p2.Init(config)) + assert.FatalError(t, p3.Init(config)) + + t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) + assert.FatalError(t, err) + t2, err := generateSimpleToken("the-issuer", p2.ClientID, &keys.Keys[1]) + assert.FatalError(t, err) + t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0]) + assert.FatalError(t, err) + + // Admin email not in domains + okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) + assert.FatalError(t, err) + // Invalid email + failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) + assert.FatalError(t, err) + failDomain, err := generateToken("subject", "the-issuer", p3.ClientID, "name@example.com", []string{}, time.Now(), &keys.Keys[0]) + assert.FatalError(t, err) + + // Invalid tokens + parts := strings.Split(t1, ".") + key, err := generateJSONWebKey() + assert.FatalError(t, err) + // missing key + failKey, err := generateSimpleToken("the-issuer", p1.ClientID, key) + assert.FatalError(t, err) + // invalid token + failTok := "foo." + parts[1] + "." + parts[2] + // invalid claims + failClaims := parts[0] + ".foo." + parts[1] + // invalid issuer + failIss, err := generateSimpleToken("bad-issuer", p1.ClientID, &keys.Keys[0]) + assert.FatalError(t, err) + // invalid audience + failAud, err := generateSimpleToken("the-issuer", "foobar", &keys.Keys[0]) + assert.FatalError(t, err) + // invalid signature + failSig := t1[0 : len(t1)-2] + // expired + failExp, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0]) + assert.FatalError(t, err) + // not before + failNbf, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(360*time.Second), &keys.Keys[0]) + assert.FatalError(t, err) + + type args struct { + token string + } + tests := []struct { + name string + prov *OIDC + args args + wantErr bool + }{ + {"ok1", p1, args{t1}, false}, + {"ok2", p2, args{t2}, false}, + {"admin", p3, args{t3}, false}, + {"admin", p3, args{okAdmin}, false}, + {"fail-email", p3, args{failEmail}, true}, + {"fail-domain", p3, args{failDomain}, true}, + {"fail-key", p1, args{failKey}, true}, + {"fail-token", p1, args{failTok}, true}, + {"fail-claims", p1, args{failClaims}, true}, + {"fail-issuer", p1, args{failIss}, true}, + {"fail-audience", p1, args{failAud}, true}, + {"fail-signature", p1, args{failSig}, true}, + {"fail-expired", p1, args{failExp}, true}, + {"fail-not-before", p1, args{failNbf}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.prov.Authorize(tt.args.token) + if (err != nil) != tt.wantErr { + fmt.Println(tt) + t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + assert.Nil(t, got) + } else { + assert.NotNil(t, got) + if tt.name == "admin" { + assert.Len(t, 3, got) + } else { + assert.Len(t, 4, got) + } + } + }) + } +} + +func TestOIDC_AuthorizeRenewal(t *testing.T) { + p1, err := generateOIDC() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + + // disable renewal + disable := true + p2.Claims = &Claims{ + globalClaims: &globalProvisionerClaims, + DisableRenewal: &disable, + } + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + prov *OIDC + args args + wantErr bool + }{ + {"ok", p1, args{nil}, false}, + {"fail", p2, args{nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("OIDC.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestOIDC_AuthorizeRevoke(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + + var keys jose.JSONWebKeySet + assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys)) + + // Create test provisioners + p1, err := generateOIDC() + assert.FatalError(t, err) + + // Update configuration endpoints and initialize + config := Config{Claims: globalProvisionerClaims} + p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + assert.FatalError(t, p1.Init(config)) + + t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) + assert.FatalError(t, err) + + type args struct { + token string + } + tests := []struct { + name string + prov *OIDC + args args + wantErr bool + }{ + {"disabled", p1, args{t1}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("OIDC.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_sanitizeEmail(t *testing.T) { + tests := []struct { + name string + email string + want string + }{ + {"equal", "name@smallstep.com", "name@smallstep.com"}, + {"domain-insensitive", "name@SMALLSTEP.COM", "name@smallstep.com"}, + {"local-sensitive", "NaMe@smallSTEP.CoM", "NaMe@smallstep.com"}, + {"multiple-@", "NaMe@NaMe@smallSTEP.CoM", "NaMe@NaMe@smallstep.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := sanitizeEmail(tt.email); got != tt.want { + t.Errorf("sanitizeEmail() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go new file mode 100644 index 00000000..aff1c54c --- /dev/null +++ b/authority/provisioner/provisioner.go @@ -0,0 +1,82 @@ +package provisioner + +import ( + "crypto/x509" + "encoding/json" + "strings" + + "github.com/pkg/errors" +) + +// Interface is the interface that all provisioner types must implement. +type Interface interface { + GetID() string + GetName() string + GetType() Type + GetEncryptedKey() (kid string, key string, ok bool) + Init(config Config) error + Authorize(token string) ([]SignOption, error) + AuthorizeRenewal(cert *x509.Certificate) error + AuthorizeRevoke(token string) error +} + +// Type indicates the provisioner Type. +type Type int + +const ( + noopType Type = 0 + + // TypeJWK is used to indicate the JWK provisioners. + TypeJWK Type = 1 + + // TypeOIDC is used to indicate the OIDC provisioners. + TypeOIDC Type = 2 +) + +// Config defines the default parameters used in the initialization of +// provisioners. +type Config struct { + // Claims are the default claims. + Claims Claims + // Audiences are the audiences used in the default provisioner, (JWK). + Audiences []string +} + +type provisioner struct { + Type string `json:"type"` +} + +// List represents a list of provisioners. +type List []Interface + +// UnmarshalJSON implements json.Unmarshaler and allows to unmarshal a list of a +// interfaces into the right type. +func (l *List) UnmarshalJSON(data []byte) error { + ps := []json.RawMessage{} + if err := json.Unmarshal(data, &ps); err != nil { + return errors.Wrap(err, "error unmarshaling provisioner list") + } + + *l = List{} + for _, data := range ps { + var typ provisioner + if err := json.Unmarshal(data, &typ); err != nil { + return errors.Errorf("error unmarshaling provisioner") + } + var p Interface + switch strings.ToLower(typ.Type) { + case "jwk": + p = &JWK{} + case "oidc": + p = &OIDC{} + default: + return errors.Errorf("provisioner type %s not supported", typ.Type) + } + if err := json.Unmarshal(data, p); err != nil { + return errors.Errorf("error unmarshaling provisioner") + } + *l = append(*l, p) + } + + return nil +} diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go new file mode 100644 index 00000000..c28fd80b --- /dev/null +++ b/authority/provisioner/sign_options.go @@ -0,0 +1,233 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "net" + "reflect" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/cli/crypto/x509util" +) + +// Options contains the options that can be passed to the Sign method. +type Options struct { + NotAfter time.Time `json:"notAfter"` + NotBefore time.Time `json:"notBefore"` +} + +// SignOption is the interface used to collect all extra options used in the +// Sign method. +type SignOption interface{} + +// CertificateValidator is the interface used to validate a X.509 certificate. +type CertificateValidator interface { + SignOption + Valid(crt *x509.Certificate) error +} + +// CertificateRequestValidator is the interface used to validate a X.509 +// certificate request. +type CertificateRequestValidator interface { + SignOption + Valid(req *x509.CertificateRequest) error +} + +// ProfileModifier is the interface used to add custom options to the profile +// constructor. The options are used to modify the final certificate. +type ProfileModifier interface { + SignOption + Option(o Options) x509util.WithOption +} + +// profileWithOption is a wrapper against x509util.WithOption to conform the +// interface. +type profileWithOption x509util.WithOption + +func (v profileWithOption) Option(Options) x509util.WithOption { + return x509util.WithOption(v) +} + +// profileDefaultDuration is a wrapper against x509util.WithOption to conform the +// interface. +type profileDefaultDuration time.Duration + +func (v profileDefaultDuration) Option(so Options) x509util.WithOption { + return x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, time.Duration(v)) +} + +// emailOnlyIdentity is a CertificateRequestValidator that checks that the only +// SAN provided is the given email address. +type emailOnlyIdentity string + +func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error { + switch { + case len(req.DNSNames) > 0: + return errors.New("certificate request cannot contain DNS names") + case len(req.IPAddresses) > 0: + return errors.New("certificate request cannot contain IP addresses") + case len(req.URIs) > 0: + return errors.New("certificate request cannot contain URIs") + case len(req.EmailAddresses) == 0: + return errors.New("certificate request does not contain any email address") + case len(req.EmailAddresses) > 1: + return errors.New("certificate request does not contain too many email addresses") + case req.EmailAddresses[0] == "": + return errors.New("certificate request cannot contain an empty email address") + case req.EmailAddresses[0] != string(e): + return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e) + default: + return nil + } +} + +// commonNameValidator validates the common name of a certificate request. +type commonNameValidator string + +// Valid checks that certificate request common name matches the one configured. +func (v commonNameValidator) Valid(req *x509.CertificateRequest) error { + if req.Subject.CommonName == "" { + return errors.New("certificate request cannot contain an empty common name") + } + if req.Subject.CommonName != string(v) { + return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v) + } + return nil +} + +// dnsNamesValidator validates the DNS names SAN of a certificate request. +type dnsNamesValidator []string + +// Valid checks that certificate request DNS Names match those configured in +// the bootstrap (token) flow. +func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error { + want := make(map[string]bool) + for _, s := range v { + want[s] = true + } + got := make(map[string]bool) + for _, s := range req.DNSNames { + got[s] = true + } + if !reflect.DeepEqual(want, got) { + return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v) + } + return nil +} + +// ipAddressesValidator validates the IP addresses SAN of a certificate request. +type ipAddressesValidator []net.IP + +// Valid checks that certificate request IP Addresses match those configured in +// the bootstrap (token) flow. +func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error { + want := make(map[string]bool) + for _, ip := range v { + want[ip.String()] = true + } + got := make(map[string]bool) + for _, ip := range req.IPAddresses { + got[ip.String()] = true + } + if !reflect.DeepEqual(want, got) { + return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v) + } + return nil +} + +// validityValidator validates the certificate temporal validity settings. +type validityValidator struct { + min time.Duration + max time.Duration +} + +// newValidityValidator return a new validity validator. +func newValidityValidator(min, max time.Duration) *validityValidator { + return &validityValidator{min: min, max: max} +} + +// Validate validates the certificate temporal validity settings. +func (v *validityValidator) Valid(crt *x509.Certificate) error { + var ( + na = crt.NotAfter + nb = crt.NotBefore + d = na.Sub(nb) + now = time.Now() + ) + + if na.Before(now) { + return errors.Errorf("NotAfter: %v cannot be in the past", na) + } + if na.Before(nb) { + return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb) + } + if d < v.min { + return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v", + d, v.min) + } + if d > v.max { + return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v", + d, v.max) + } + return nil +} + +var ( + stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} + stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) +) + +type stepProvisionerASN1 struct { + Type int + Name []byte + CredentialID []byte +} + +type provisionerExtensionOption struct { + Type int + Name string + CredentialID string +} + +func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisionerExtensionOption { + return &provisionerExtensionOption{ + Type: int(typ), + Name: name, + CredentialID: credentialID, + } +} + +func (o *provisionerExtensionOption) Option(Options) x509util.WithOption { + return func(p x509util.Profile) error { + crt := p.Subject() + ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID) + if err != nil { + return err + } + crt.ExtraExtensions = append(crt.ExtraExtensions, ext) + return nil + } +} + +func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) { + b, err := asn1.Marshal(stepProvisionerASN1{ + Type: typ, + Name: []byte(name), + CredentialID: []byte(credentialID), + }) + if err != nil { + return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension") + } + return pkix.Extension{ + Id: stepOIDProvisioner, + Critical: false, + Value: b, + }, nil +} + +func init() { + // Avoid deadcode warning in profileWithOption + _ = profileWithOption(nil) +} diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go new file mode 100644 index 00000000..e1349974 --- /dev/null +++ b/authority/provisioner/sign_options_test.go @@ -0,0 +1,152 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "net" + "net/url" + "testing" + "time" +) + +func Test_emailOnlyIdentity_Valid(t *testing.T) { + uri, err := url.Parse("https://example.com/1.0/getUser") + if err != nil { + t.Fatal(err) + } + + type args struct { + req *x509.CertificateRequest + } + tests := []struct { + name string + e emailOnlyIdentity + args args + wantErr bool + }{ + {"ok", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com"}}}, false}, + {"DNSNames", "name@smallstep.com", args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, true}, + {"IPAddresses", "name@smallstep.com", args{&x509.CertificateRequest{IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}}}, true}, + {"URIs", "name@smallstep.com", args{&x509.CertificateRequest{URIs: []*url.URL{uri}}}, true}, + {"no-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{}}}, true}, + {"empty-email", "", args{&x509.CertificateRequest{EmailAddresses: []string{""}}}, true}, + {"multiple-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com", "foo@smallstep.com"}}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.e.Valid(tt.args.req); (err != nil) != tt.wantErr { + t.Errorf("emailOnlyIdentity.Valid() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_commonNameValidator_Valid(t *testing.T) { + type args struct { + req *x509.CertificateRequest + } + tests := []struct { + name string + v commonNameValidator + args args + wantErr bool + }{ + {"ok", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false}, + {"empty", "", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, true}, + {"wrong", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { + t.Errorf("commonNameValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_dnsNamesValidator_Valid(t *testing.T) { + type args struct { + req *x509.CertificateRequest + } + tests := []struct { + name string + v dnsNamesValidator + args args + wantErr bool + }{ + {"ok0", []string{}, args{&x509.CertificateRequest{DNSNames: []string{}}}, false}, + {"ok1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, false}, + {"ok2", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "bar.zar"}}}, false}, + {"ok3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, false}, + {"fail1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar"}}}, true}, + {"fail2", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, true}, + {"fail3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "zar.bar"}}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { + t.Errorf("dnsNamesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_ipAddressesValidator_Valid(t *testing.T) { + ip1 := net.IPv4(10, 3, 2, 1) + ip2 := net.IPv4(10, 3, 2, 2) + ip3 := net.IPv4(10, 3, 2, 3) + + type args struct { + req *x509.CertificateRequest + } + tests := []struct { + name string + v ipAddressesValidator + args args + wantErr bool + }{ + {"ok0", []net.IP{}, args{&x509.CertificateRequest{IPAddresses: []net.IP{}}}, false}, + {"ok1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1}}}, false}, + {"ok2", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip2}}}, false}, + {"ok3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, false}, + {"fail1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2}}}, true}, + {"fail2", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, true}, + {"fail3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip3}}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { + t.Errorf("ipAddressesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_validityValidator_Valid(t *testing.T) { + type fields struct { + min time.Duration + max time.Duration + } + type args struct { + crt *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := &validityValidator{ + min: tt.fields.min, + max: tt.fields.max, + } + if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr { + t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/provisioner/testdata/root_ca.crt b/authority/provisioner/testdata/root_ca.crt new file mode 100644 index 00000000..c802b420 --- /dev/null +++ b/authority/provisioner/testdata/root_ca.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf +MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla +Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg +Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN +Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw +QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU +B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c +ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET +/A8LXNH4M06A7vE= +-----END CERTIFICATE----- diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go new file mode 100644 index 00000000..e53b42d3 --- /dev/null +++ b/authority/provisioner/utils_test.go @@ -0,0 +1,272 @@ +package provisioner + +import ( + "crypto" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "time" + + "github.com/smallstep/cli/crypto/randutil" + "github.com/smallstep/cli/jose" +) + +var testAudiences = []string{ + "https://ca.smallstep.com/sign", + "https://ca.smallsteomcom/1.0/sign", +} + +func must(args ...interface{}) []interface{} { + if l := len(args); l > 0 && args[l-1] != nil { + if err, ok := args[l-1].(error); ok { + panic(err) + } + } + return args +} + +func generateJSONWebKey() (*jose.JSONWebKey, error) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + if err != nil { + return nil, err + } + fp, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return nil, err + } + jwk.KeyID = string(hex.EncodeToString(fp)) + return jwk, nil +} + +func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) { + var keySet jose.JSONWebKeySet + for i := 0; i < n; i++ { + key, err := generateJSONWebKey() + if err != nil { + return jose.JSONWebKeySet{}, err + } + keySet.Keys = append(keySet.Keys, *key) + } + return keySet, nil +} + +func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) { + b, err := json.Marshal(jwk) + if err != nil { + return nil, err + } + salt, err := randutil.Salt(jose.PBKDF2SaltSize) + if err != nil { + return nil, err + } + opts := new(jose.EncrypterOptions) + opts.WithContentType(jose.ContentType("jwk+json")) + recipient := jose.Recipient{ + Algorithm: jose.PBES2_HS256_A128KW, + Key: []byte("password"), + PBES2Count: jose.PBKDF2Iterations, + PBES2Salt: salt, + } + encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts) + if err != nil { + return nil, err + } + return encrypter.Encrypt(b) +} + +func decryptJSONWebKey(key string) (*jose.JSONWebKey, error) { + enc, err := jose.ParseEncrypted(key) + if err != nil { + return nil, err + } + b, err := enc.Decrypt([]byte("password")) + if err != nil { + return nil, err + } + jwk := new(jose.JSONWebKey) + if err := json.Unmarshal(b, jwk); err != nil { + return nil, err + } + return jwk, nil +} + +func generateJWK() (*JWK, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + jwk, err := generateJSONWebKey() + if err != nil { + return nil, err + } + jwe, err := encryptJSONWebKey(jwk) + if err != nil { + return nil, err + } + public := jwk.Public() + encrypted, err := jwe.CompactSerialize() + if err != nil { + return nil, err + } + return &JWK{ + Name: name, + Type: "JWK", + Key: &public, + EncryptedKey: encrypted, + Claims: &globalProvisionerClaims, + audiences: testAudiences, + }, nil +} + +func generateOIDC() (*OIDC, error) { + name, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + clientID, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + issuer, err := randutil.Alphanumeric(10) + if err != nil { + return nil, err + } + jwk, err := generateJSONWebKey() + if err != nil { + return nil, err + } + return &OIDC{ + Name: name, + Type: "OIDC", + ClientID: clientID, + ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration", + Claims: &globalProvisionerClaims, + configuration: openIDConfiguration{ + Issuer: issuer, + JWKSetURI: "https://example.com/.well-known/jwks", + }, + keyStore: &keyStore{ + keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, + expiry: time.Now().Add(24 * time.Hour), + }, + }, nil +} + +func generateCollection(nJWK, nOIDC int) (*Collection, error) { + col := NewCollection(testAudiences) + for i := 0; i < nJWK; i++ { + p, err := generateJWK() + if err != nil { + return nil, err + } + col.Store(p) + } + for i := 0; i < nOIDC; i++ { + p, err := generateOIDC() + if err != nil { + return nil, err + } + col.Store(p) + } + return col, nil +} + +func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { + return generateToken("subject", iss, aud, "name@smallstep.com", []string{"test.smallstep.com"}, time.Now(), jwk) +} + +func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { + sig, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, + new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), + ) + if err != nil { + return "", err + } + + id, err := randutil.ASCII(64) + if err != nil { + return "", err + } + + claims := struct { + jose.Claims + Email string `json:"email"` + SANS []string `json:"sans"` + }{ + Claims: jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + }, + Email: email, + SANS: sans, + } + return jose.Signed(sig).Claims(claims).CompactSerialize() +} + +func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) { + tok, err := jose.ParseSigned(token) + if err != nil { + return nil, nil, err + } + claims := new(jose.Claims) + if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil { + return nil, nil, err + } + return tok, claims, nil +} + +func generateJWKServer(n int) *httptest.Server { + hits := struct { + Hits int `json:"hits"` + }{} + writeJSON := func(w http.ResponseWriter, v interface{}) { + b, err := json.Marshal(v) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(b) + } + getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet { + var ret jose.JSONWebKeySet + for _, k := range ks.Keys { + ret.Keys = append(ret.Keys, k.Public()) + } + return ret + } + + defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) + srv := httptest.NewUnstartedServer(nil) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits.Hits++ + switch r.RequestURI { + case "/error": + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + case "/hits": + writeJSON(w, hits) + case "/openid-configuration", "/.well-known/openid-configuration": + writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"}) + case "/random": + keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) + w.Header().Add("Cache-Control", "max-age=5") + writeJSON(w, getPublic(keySet)) + case "/private": + writeJSON(w, defaultKeySet) + default: + w.Header().Add("Cache-Control", "max-age=5") + writeJSON(w, getPublic(defaultKeySet)) + } + }) + + srv.Start() + return srv +} diff --git a/authority/provisioner_test.go b/authority/provisioner_test.go deleted file mode 100644 index 5135636e..00000000 --- a/authority/provisioner_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package authority - -import ( - "errors" - "testing" - - "github.com/smallstep/assert" - jose "gopkg.in/square/go-jose.v2" -) - -func TestProvisionerInit(t *testing.T) { - type ProvisionerValidateTest struct { - p *Provisioner - err error - } - tests := map[string]func(*testing.T) ProvisionerValidateTest{ - "fail-empty-name": func(t *testing.T) ProvisionerValidateTest { - return ProvisionerValidateTest{ - p: &Provisioner{}, - err: errors.New("provisioner name cannot be empty"), - } - }, - "fail-empty-type": func(t *testing.T) ProvisionerValidateTest { - return ProvisionerValidateTest{ - p: &Provisioner{Name: "foo"}, - err: errors.New("provisioner type cannot be empty"), - } - }, - "fail-empty-key": func(t *testing.T) ProvisionerValidateTest { - return ProvisionerValidateTest{ - p: &Provisioner{Name: "foo", Type: "bar"}, - err: errors.New("provisioner key cannot be empty"), - } - }, - "ok": func(t *testing.T) ProvisionerValidateTest { - return ProvisionerValidateTest{ - p: &Provisioner{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, - } - }, - } - - for name, get := range tests { - t.Run(name, func(t *testing.T) { - tc := get(t) - err := tc.p.Init(&globalProvisionerClaims) - if err != nil { - if assert.NotNil(t, tc.err) { - assert.Equals(t, tc.err.Error(), err.Error()) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} diff --git a/authority/provisioners.go b/authority/provisioners.go index 85713b7e..d3072d12 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -1,115 +1,25 @@ package authority import ( - "crypto/sha1" - "encoding/binary" - "encoding/hex" - "encoding/json" - "fmt" - "math" "net/http" - "sort" - "strings" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" ) -// DefaultProvisionersLimit is the default limit for listing provisioners. -const DefaultProvisionersLimit = 20 - -// DefaultProvisionersMax is the maximum limit for listing provisioners. -const DefaultProvisionersMax = 100 - // GetEncryptedKey returns the JWE key corresponding to the given kid argument. func (a *Authority) GetEncryptedKey(kid string) (string, error) { - val, ok := a.encryptedKeyIndex.Load(kid) + key, ok := a.provisioners.LoadEncryptedKey(kid) if !ok { return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid), http.StatusNotFound, context{}} } - - key, ok := val.(string) - if !ok { - return "", &apiError{errors.Errorf("stored value is not a string"), - http.StatusInternalServerError, context{}} - } return key, nil } // GetProvisioners returns a map listing each provisioner and the JWK Key Set // with their public keys. -func (a *Authority) GetProvisioners(cursor string, limit int) ([]*Provisioner, string, error) { - provisioners, nextCursor := a.sortedProvisioners.Find(cursor, limit) +func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) { + provisioners, nextCursor := a.provisioners.Find(cursor, limit) return provisioners, nextCursor, nil } - -type uidProvisioner struct { - provisioner *Provisioner - uid string -} - -func newSortedProvisioners(provisioners []*Provisioner) (provisionerSlice, error) { - if len(provisioners) > math.MaxInt32 { - return nil, errors.New("too many provisioners") - } - - var slice provisionerSlice - bi := make([]byte, 4) - for i, p := range provisioners { - sum, err := provisionerSum(p) - if err != nil { - return nil, err - } - // Use the first 4 bytes (32bit) of the sum to insert the order - // Using big endian format to get the strings sorted: - // 0x00000000, 0x00000001, 0x00000002, ... - binary.BigEndian.PutUint32(bi, uint32(i)) - sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3] - bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0 - slice = append(slice, uidProvisioner{ - provisioner: p, - uid: hex.EncodeToString(sum), - }) - } - sort.Sort(slice) - return slice, nil -} - -type provisionerSlice []uidProvisioner - -func (p provisionerSlice) Len() int { return len(p) } -func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } -func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } - -func (p provisionerSlice) Find(cursor string, limit int) ([]*Provisioner, string) { - switch { - case limit <= 0: - limit = DefaultProvisionersLimit - case limit > DefaultProvisionersMax: - limit = DefaultProvisionersMax - } - - n := len(p) - cursor = fmt.Sprintf("%040s", cursor) - i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor }) - - var slice []*Provisioner - for ; i < n && len(slice) < limit; i++ { - slice = append(slice, p[i].provisioner) - } - if i < n { - return slice, strings.TrimLeft(p[i].uid, "0") - } - return slice, "" -} - -// provisionerSum returns the SHA1 of the json representation of the -// provisioner. From this we will create the unique and sorted id. -func provisionerSum(p *Provisioner) ([]byte, error) { - b, err := json.Marshal(p.Key) - if err != nil { - return nil, errors.Wrap(err, "error marshalling provisioner") - } - sum := sha1.Sum(b) - return sum[:], nil -} diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index b982a366..303c4e8a 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -1,16 +1,12 @@ package authority import ( - "encoding/json" "net/http" - "reflect" - "strings" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/cli/crypto/randutil" - "github.com/smallstep/cli/jose" + "github.com/smallstep/certificates/authority/provisioner" ) func TestGetEncryptedKey(t *testing.T) { @@ -27,7 +23,7 @@ func TestGetEncryptedKey(t *testing.T) { assert.FatalError(t, err) return &ek{ a: a, - kid: c.AuthorityConfig.Provisioners[1].Key.KeyID, + kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID, } }, "fail-not-found": func(t *testing.T) *ek { @@ -42,19 +38,6 @@ func TestGetEncryptedKey(t *testing.T) { http.StatusNotFound, context{}}, } }, - "fail-invalid-type-found": func(t *testing.T) *ek { - c, err := LoadConfiguration("../ca/testdata/ca.json") - assert.FatalError(t, err) - a, err := New(c) - assert.FatalError(t, err) - a.encryptedKeyIndex.Store("foo", 5) - return &ek{ - a: a, - kid: "foo", - err: &apiError{errors.Errorf("stored value is not a string"), - http.StatusInternalServerError, context{}}, - } - }, } for name, genTestCase := range tests { @@ -75,9 +58,9 @@ func TestGetEncryptedKey(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - val, ok := tc.a.provisionerIDIndex.Load("max:" + tc.kid) + val, ok := tc.a.provisioners.Load("max:" + tc.kid) assert.Fatal(t, ok) - p, ok := val.(*Provisioner) + p, ok := val.(*provisioner.JWK) assert.Fatal(t, ok) assert.Equals(t, p.EncryptedKey, ek) } @@ -126,102 +109,3 @@ func TestGetProvisioners(t *testing.T) { }) } } - -func generateProvisioner(t *testing.T) *Provisioner { - name, err := randutil.Alphanumeric(10) - assert.FatalError(t, err) - // Create a new JWK - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - // Encrypt JWK - salt, err := randutil.Salt(jose.PBKDF2SaltSize) - assert.FatalError(t, err) - b, err := json.Marshal(jwk) - assert.FatalError(t, err) - recipient := jose.Recipient{ - Algorithm: jose.PBES2_HS256_A128KW, - Key: []byte("password"), - PBES2Count: jose.PBKDF2Iterations, - PBES2Salt: salt, - } - opts := new(jose.EncrypterOptions) - opts.WithContentType(jose.ContentType("jwk+json")) - encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts) - assert.FatalError(t, err) - jwe, err := encrypter.Encrypt(b) - assert.FatalError(t, err) - // get public and encrypted keys - public := jwk.Public() - encrypted, err := jwe.CompactSerialize() - assert.FatalError(t, err) - return &Provisioner{ - Name: name, - Type: "JWT", - Key: &public, - EncryptedKey: encrypted, - } -} - -func Test_newSortedProvisioners(t *testing.T) { - provisioners := make([]*Provisioner, 20) - for i := range provisioners { - provisioners[i] = generateProvisioner(t) - } - - ps, err := newSortedProvisioners(provisioners) - assert.FatalError(t, err) - prev := "" - for i, p := range ps { - if p.uid < prev { - t.Errorf("%s should be less that %s", p.uid, prev) - } - if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID { - t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID) - } - prev = p.uid - } -} - -func Test_provisionerSlice_Find(t *testing.T) { - trim := func(s string) string { - return strings.TrimLeft(s, "0") - } - provisioners := make([]*Provisioner, 20) - for i := range provisioners { - provisioners[i] = generateProvisioner(t) - } - ps, err := newSortedProvisioners(provisioners) - assert.FatalError(t, err) - - type args struct { - cursor string - limit int - } - tests := []struct { - name string - p provisionerSlice - args args - want []*Provisioner - want1 string - }{ - {"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""}, - {"0 to 19", ps, args{"", 20}, provisioners[0:20], ""}, - {"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)}, - {"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""}, - {"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)}, - {"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)}, - {"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""}, - {"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want) - } - if got1 != tt.want1 { - t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1) - } - }) - } -} diff --git a/authority/tls.go b/authority/tls.go index ad5ea808..c52ac1e8 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -3,7 +3,6 @@ package authority import ( "crypto/tls" "crypto/x509" - "crypto/x509/pkix" "encoding/asn1" "encoding/pem" "net/http" @@ -11,6 +10,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/x509util" @@ -22,48 +22,7 @@ func (a *Authority) GetTLSOptions() *tlsutil.TLSOptions { return a.config.TLS } -// SignOptions contains the options that can be passed to the Authority.Sign -// method. -type SignOptions struct { - NotAfter time.Time `json:"notAfter"` - NotBefore time.Time `json:"notBefore"` -} - -var ( - stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} - stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) - oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35} -) - -type stepProvisionerASN1 struct { - Type int - Name []byte - CredentialID []byte -} - -const provisionerTypeJWK = 1 - -func withProvisionerOID(name, kid string) x509util.WithOption { - return func(p x509util.Profile) error { - crt := p.Subject() - - b, err := asn1.Marshal(stepProvisionerASN1{ - Type: provisionerTypeJWK, - Name: []byte(name), - CredentialID: []byte(kid), - }) - if err != nil { - return err - } - crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{ - Id: stepOIDProvisioner, - Critical: false, - Value: b, - }) - - return nil - } -} +var oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35} func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption { return func(p x509util.Profile) error { @@ -96,28 +55,22 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption { } // Sign creates a signed certificate from a certificate signing request. -func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) { +func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) { var ( - errContext = context{"csr": csr, "signOptions": signOpts} - claims = []certClaim{} - mods = []x509util.WithOption{} + errContext = context{"csr": csr, "signOptions": signOpts} + mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)} + certValidators = []provisioner.CertificateValidator{} ) for _, op := range extraOpts { switch k := op.(type) { - case certClaim: - claims = append(claims, k) - case x509util.WithOption: - mods = append(mods, k) - case *Provisioner: - m, c, err := k.getTLSApps(signOpts) - if err != nil { - return nil, nil, &apiError{err, http.StatusInternalServerError, errContext} + case provisioner.CertificateValidator: + certValidators = append(certValidators, k) + case provisioner.CertificateRequestValidator: + if err := k.Valid(csr); err != nil { + return nil, nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext} } - mods = append(mods, m...) - mods = append(mods, []x509util.WithOption{ - withDefaultASN1DN(a.config.AuthorityConfig.Template), - }...) - claims = append(claims, c...) + case provisioner.ProfileModifier: + mods = append(mods, k.Option(signOpts)) default: return nil, nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k), http.StatusInternalServerError, errContext} @@ -137,10 +90,6 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, ext return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext} } - if err := validateClaims(leaf.Subject(), claims); err != nil { - return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusUnauthorized, errContext} - } - crtBytes, err := leaf.CreateCertificate() if err != nil { return nil, nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"), @@ -153,6 +102,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, ext http.StatusInternalServerError, errContext} } + // FIXME: This should be before creating the certificate. + for _, v := range certValidators { + if err := v.Valid(serverCert); err != nil { + return nil, nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext} + } + } + caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw) if err != nil { return nil, nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"), diff --git a/authority/tls_test.go b/authority/tls_test.go index 70b3d7a1..47ac7966 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -7,7 +7,6 @@ import ( "crypto/x509/pkix" "encoding/asn1" "fmt" - "net" "net/http" "reflect" "testing" @@ -15,12 +14,49 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/crypto/x509util" + "github.com/smallstep/cli/jose" stepx509 "github.com/smallstep/cli/pkg/x509" ) +var ( + stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} + stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) +) + +const provisionerTypeJWK = 1 + +type stepProvisionerASN1 struct { + Type int + Name []byte + CredentialID []byte +} + +func withProvisionerOID(name, kid string) x509util.WithOption { + return func(p x509util.Profile) error { + crt := p.Subject() + + b, err := asn1.Marshal(stepProvisionerASN1{ + Type: provisionerTypeJWK, + Name: []byte(name), + CredentialID: []byte(kid), + }) + if err != nil { + return err + } + crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{ + Id: stepOIDProvisioner, + Critical: false, + Value: b, + }) + + return nil + } +} + func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateRequest)) *x509.CertificateRequest { _csr := &x509.CertificateRequest{ Subject: pkix.Name{CommonName: "smallstep test"}, @@ -52,24 +88,25 @@ func TestSign(t *testing.T) { } nb := time.Now() - signOpts := SignOptions{ + signOpts := provisioner.Options{ NotBefore: nb, NotAfter: nb.Add(time.Minute * 5), } - p := a.config.AuthorityConfig.Provisioners[1] - extraOpts := []interface{}{ - &commonNameClaim{"smallstep test"}, - &dnsNamesClaim{[]string{"test.smallstep.com"}}, - &ipAddressesClaim{[]net.IP{}}, - p, - } + // Create a token to get test extra opts. + p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK) + key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key) + assert.FatalError(t, err) + extraOpts, err := a.Authorize(token) + assert.FatalError(t, err) type signTest struct { auth *Authority csr *x509.CertificateRequest - signOpts SignOptions - extraOpts []interface{} + signOpts provisioner.Options + extraOpts []provisioner.SignOption err *apiError } tests := map[string]func(*testing.T) *signTest{ @@ -123,7 +160,7 @@ func TestSign(t *testing.T) { return &signTest{ auth: _a, csr: csr, - extraOpts: []interface{}{p}, + extraOpts: extraOpts, signOpts: signOpts, err: &apiError{errors.New("sign: error creating new leaf certificate"), http.StatusInternalServerError, @@ -133,7 +170,7 @@ func TestSign(t *testing.T) { }, "fail provisioner duration claim": func(t *testing.T) *signTest { csr := getCSR(t, priv) - _signOpts := SignOptions{ + _signOpts := provisioner.Options{ NotBefore: nb, NotAfter: nb.Add(time.Hour * 25), } @@ -157,7 +194,7 @@ func TestSign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: &apiError{errors.New("sign: DNS names claim failed - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), + err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), http.StatusUnauthorized, context{"csr": csr, "signOptions": signOpts}, }, @@ -262,7 +299,7 @@ func TestRenew(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) na1 := now - so := &SignOptions{ + so := &provisioner.Options{ NotBefore: nb1, NotAfter: na1, } @@ -272,7 +309,7 @@ func TestRenew(t *testing.T) { x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0), withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"), - withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].Key.KeyID)) + withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID)) assert.FatalError(t, err) crtBytes, err := leaf.CreateCertificate() assert.FatalError(t, err) @@ -284,7 +321,7 @@ func TestRenew(t *testing.T) { x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0), withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"), - withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), + withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID), ) assert.FatalError(t, err) crtBytesNoRenew, err := leafNoRenew.CreateCertificate() @@ -321,7 +358,7 @@ func TestRenew(t *testing.T) { } return &renewTest{ crt: crtNoRenew, - err: &apiError{errors.New("renew disabled"), + err: &apiError{errors.New("renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), http.StatusUnauthorized, ctx}, }, nil }, diff --git a/authority/types.go b/authority/types.go index f0a781d5..0d0f2a90 100644 --- a/authority/types.go +++ b/authority/types.go @@ -2,48 +2,10 @@ package authority import ( "encoding/json" - "time" "github.com/pkg/errors" ) -// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal. -type Duration struct { - time.Duration -} - -// MarshalJSON parses a duration string and sets it to the duration. -// -// A duration string is a possibly signed sequence of decimal numbers, each with -// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". -// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *Duration) MarshalJSON() ([]byte, error) { - return json.Marshal(d.Duration.String()) -} - -// UnmarshalJSON parses a duration string and sets it to the duration. -// -// A duration string is a possibly signed sequence of decimal numbers, each with -// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". -// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *Duration) UnmarshalJSON(data []byte) (err error) { - var ( - s string - _d time.Duration - ) - if d == nil { - return errors.New("duration cannot be nil") - } - if err = json.Unmarshal(data, &s); err != nil { - return errors.Wrapf(err, "error unmarshalling %s", data) - } - if _d, err = time.ParseDuration(s); err != nil { - return errors.Wrapf(err, "error parsing %s as duration", s) - } - d.Duration = _d - return -} - // multiString represents a type that can be encoded/decoded in JSON as a single // string or an array of strings. type multiString []string diff --git a/authority/types_test.go b/authority/types_test.go index c49c368f..352c253f 100644 --- a/authority/types_test.go +++ b/authority/types_test.go @@ -3,7 +3,6 @@ package authority import ( "reflect" "testing" - "time" ) func Test_multiString_First(t *testing.T) { @@ -101,57 +100,3 @@ func Test_multiString_UnmarshalJSON(t *testing.T) { }) } } - -func TestDuration_UnmarshalJSON(t *testing.T) { - type args struct { - data []byte - } - tests := []struct { - name string - d *Duration - args args - want *Duration - wantErr bool - }{ - {"empty", new(Duration), args{[]byte{}}, new(Duration), true}, - {"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true}, - {"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true}, - {"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true}, - {"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false}, - {"nil", nil, args{nil}, nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { - t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(tt.d, tt.want) { - t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want) - } - }) - } -} - -func Test_duration_MarshalJSON(t *testing.T) { - tests := []struct { - name string - d *Duration - want []byte - wantErr bool - }{ - {"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.d.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/autocert/controller/provisioner.go b/autocert/controller/provisioner.go index 453b87e8..857127ad 100644 --- a/autocert/controller/provisioner.go +++ b/autocert/controller/provisioner.go @@ -7,7 +7,7 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority" + provisioners "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/ca" "github.com/smallstep/cli/config" "github.com/smallstep/cli/crypto/randutil" @@ -111,10 +111,12 @@ func loadProvisionerJWKByName(name, caURL, caRoot, passFile string) (key *jose.J } for _, provisioner := range provisioners { - if provisioner.Name == name { - key, err = decryptProvisionerJWK(provisioner.EncryptedKey, passFile) - if err == nil { - return + if provisioner.GetName() == name { + if _, encryptedKey, ok := provisioner.GetEncryptedKey(); ok { + key, err = decryptProvisionerJWK(encryptedKey, passFile) + if err == nil { + return + } } } } @@ -154,7 +156,7 @@ func getRootCAPath() string { } // getProvisioners returns the map of provisioners on the given CA. -func getProvisioners(caURL, rootFile string) ([]*authority.Provisioner, error) { +func getProvisioners(caURL, rootFile string) (provisioners.List, error) { if len(rootFile) == 0 { rootFile = getRootCAPath() } @@ -163,7 +165,7 @@ func getProvisioners(caURL, rootFile string) ([]*authority.Provisioner, error) { return nil, err } cursor := "" - provisioners := []*authority.Provisioner{} + var provisioners provisioners.List for { resp, err := client.Provisioners(ca.WithProvisionerCursor(cursor), ca.WithProvisionerLimit(100)) if err != nil { diff --git a/ca/ca_test.go b/ca/ca_test.go index 32701c05..d5fc17f7 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -20,6 +20,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/randutil" @@ -389,7 +390,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { } }, "ok": func(t *testing.T) *ekt { - p := config.AuthorityConfig.Provisioners[2] + p := config.AuthorityConfig.Provisioners[2].(*provisioner.JWK) return &ekt{ ca: ca, kid: p.Key.KeyID, diff --git a/ca/client.go b/ca/client.go index 2a0e750a..dfe63a3f 100644 --- a/ca/client.go +++ b/ca/client.go @@ -446,7 +446,11 @@ func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) return nil, nil, errors.Wrap(err, "error generating key") } + var emails []string dnsNames, ips := x509util.SplitSANs(claims.SANs) + if claims.Email != "" { + emails = append(emails, claims.Email) + } template := &x509.CertificateRequest{ Subject: pkix.Name{ @@ -455,6 +459,7 @@ func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) SignatureAlgorithm: x509.ECDSAWithSHA256, DNSNames: dnsNames, IPAddresses: ips, + EmailAddresses: emails, } csr, err := x509.CreateCertificateRequest(rand.Reader, template, pk) diff --git a/ca/client_test.go b/ca/client_test.go index d82afa31..68fefd09 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -14,7 +14,7 @@ import ( "time" "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" ) const ( @@ -391,7 +391,7 @@ func TestClient_Renew(t *testing.T) { func TestClient_Provisioners(t *testing.T) { ok := &api.ProvisionersResponse{ - Provisioners: []*authority.Provisioner{}, + Provisioners: provisioner.List{}, } internalServerError := api.InternalServerError(fmt.Errorf("Internal Server Error")) diff --git a/cmd/step-ca/main.go b/cmd/step-ca/main.go index 6bdf3497..2513b552 100644 --- a/cmd/step-ca/main.go +++ b/cmd/step-ca/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "flag" "fmt" "io/ioutil" "log" @@ -143,40 +144,12 @@ intermediate private key.`, } app.Action = func(ctx *cli.Context) error { - passFile := ctx.String("password-file") - - // If zero cmd line args show help, if >1 cmd line args show error. - if ctx.NArg() == 0 { - return cli.ShowAppHelp(ctx) - } - if err := errs.NumberOfArguments(ctx, 1); err != nil { - return err - } - - configFile := ctx.Args().Get(0) - config, err := authority.LoadConfiguration(configFile) - if err != nil { - fatal(err) - } - - var password []byte - if passFile != "" { - if password, err = ioutil.ReadFile(passFile); err != nil { - fatal(errors.Wrapf(err, "error reading %s", passFile)) - } - password = bytes.TrimRightFunc(password, unicode.IsSpace) - } - - srv, err := ca.New(config, ca.WithConfigFile(configFile), ca.WithPassword(password)) - if err != nil { - fatal(err) - } - - go ca.StopReloaderHandler(srv) - if err = srv.Run(); err != nil && err != http.ErrServerClosed { - fatal(err) - } - return nil + // Hack to be able to run a the top action as a subcommand + cmd := cli.Command{Name: "start", Action: startAction, Flags: app.Flags} + set := flag.NewFlagSet(app.Name, flag.ContinueOnError) + set.Parse(os.Args) + ctx = cli.NewContext(app, set, nil) + return cmd.Run(ctx) } if err := app.Run(os.Args); err != nil { @@ -189,6 +162,43 @@ intermediate private key.`, } } +func startAction(ctx *cli.Context) error { + passFile := ctx.String("password-file") + + // If zero cmd line args show help, if >1 cmd line args show error. + if ctx.NArg() == 0 { + return cli.ShowAppHelp(ctx) + } + if err := errs.NumberOfArguments(ctx, 1); err != nil { + return err + } + + configFile := ctx.Args().Get(0) + config, err := authority.LoadConfiguration(configFile) + if err != nil { + fatal(err) + } + + var password []byte + if passFile != "" { + if password, err = ioutil.ReadFile(passFile); err != nil { + fatal(errors.Wrapf(err, "error reading %s", passFile)) + } + password = bytes.TrimRightFunc(password, unicode.IsSpace) + } + + srv, err := ca.New(config, ca.WithConfigFile(configFile), ca.WithPassword(password)) + if err != nil { + fatal(err) + } + + go ca.StopReloaderHandler(srv) + if err = srv.Run(); err != nil && err != http.ErrServerClosed { + fatal(err) + } + return nil +} + // fatal writes the passed error on the standard error and exits with the exit // code 1. If the environment variable STEPDEBUG is set to 1 it shows the // stack trace of the error.