forked from TrueCloudLab/certificates
Merge pull request #51 from smallstep/oidc-provisioner
OIDC provisioners
This commit is contained in:
commit
095ab891e7
40 changed files with 3080 additions and 1234 deletions
18
api/api.go
18
api/api.go
|
@ -18,19 +18,19 @@ import (
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"github.com/smallstep/cli/crypto/tlsutil"
|
"github.com/smallstep/cli/crypto/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authority is the interface implemented by a CA authority.
|
// Authority is the interface implemented by a CA authority.
|
||||||
type Authority interface {
|
type Authority interface {
|
||||||
Authorize(ott string) ([]interface{}, error)
|
Authorize(ott string) ([]provisioner.SignOption, error)
|
||||||
GetTLSOptions() *tlsutil.TLSOptions
|
GetTLSOptions() *tlsutil.TLSOptions
|
||||||
Root(shasum string) (*x509.Certificate, error)
|
Root(shasum string) (*x509.Certificate, error)
|
||||||
Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
|
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||||
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
|
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
||||||
GetEncryptedKey(kid string) (string, error)
|
GetEncryptedKey(kid string) (string, error)
|
||||||
GetRoots() (federation []*x509.Certificate, err error)
|
GetRoots() (federation []*x509.Certificate, err error)
|
||||||
GetFederation() ([]*x509.Certificate, error)
|
GetFederation() ([]*x509.Certificate, error)
|
||||||
|
@ -161,11 +161,11 @@ type SignRequest struct {
|
||||||
// ProvisionersResponse is the response object that returns the list of
|
// ProvisionersResponse is the response object that returns the list of
|
||||||
// provisioners.
|
// provisioners.
|
||||||
type ProvisionersResponse struct {
|
type ProvisionersResponse struct {
|
||||||
Provisioners []*authority.Provisioner `json:"provisioners"`
|
Provisioners provisioner.List `json:"provisioners"`
|
||||||
NextCursor string `json:"nextCursor"`
|
NextCursor string `json:"nextCursor"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvisionerKeyResponse is the response object that returns the encryptoed key
|
// ProvisionerKeyResponse is the response object that returns the encrypted key
|
||||||
// of a provisioner.
|
// of a provisioner.
|
||||||
type ProvisionerKeyResponse struct {
|
type ProvisionerKeyResponse struct {
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
|
@ -266,18 +266,18 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
signOpts := authority.SignOptions{
|
opts := provisioner.Options{
|
||||||
NotBefore: body.NotBefore,
|
NotBefore: body.NotBefore,
|
||||||
NotAfter: body.NotAfter,
|
NotAfter: body.NotAfter,
|
||||||
}
|
}
|
||||||
|
|
||||||
extraOpts, err := h.Authority.Authorize(body.OTT)
|
signOpts, err := h.Authority.Authorize(body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, Unauthorized(err))
|
WriteError(w, Unauthorized(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, signOpts, extraOpts...)
|
cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, Forbidden(err))
|
WriteError(w, Forbidden(err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -24,7 +24,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"github.com/smallstep/cli/crypto/tlsutil"
|
"github.com/smallstep/cli/crypto/tlsutil"
|
||||||
"github.com/smallstep/cli/jose"
|
"github.com/smallstep/cli/jose"
|
||||||
|
@ -410,22 +410,22 @@ func TestSignRequest_Validate(t *testing.T) {
|
||||||
type mockAuthority struct {
|
type mockAuthority struct {
|
||||||
ret1, ret2 interface{}
|
ret1, ret2 interface{}
|
||||||
err error
|
err error
|
||||||
authorize func(ott string) ([]interface{}, error)
|
authorize func(ott string) ([]provisioner.SignOption, error)
|
||||||
getTLSOptions func() *tlsutil.TLSOptions
|
getTLSOptions func() *tlsutil.TLSOptions
|
||||||
root func(shasum string) (*x509.Certificate, error)
|
root func(shasum string) (*x509.Certificate, error)
|
||||||
sign func(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
|
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
|
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||||
getEncryptedKey func(kid string) (string, error)
|
getEncryptedKey func(kid string) (string, error)
|
||||||
getRoots func() ([]*x509.Certificate, error)
|
getRoots func() ([]*x509.Certificate, error)
|
||||||
getFederation func() ([]*x509.Certificate, error)
|
getFederation func() ([]*x509.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
|
func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) {
|
||||||
if m.authorize != nil {
|
if m.authorize != nil {
|
||||||
return m.authorize(ott)
|
return m.authorize(ott)
|
||||||
}
|
}
|
||||||
return m.ret1.([]interface{}), m.err
|
return m.ret1.([]provisioner.SignOption), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetTLSOptions() *tlsutil.TLSOptions {
|
func (m *mockAuthority) GetTLSOptions() *tlsutil.TLSOptions {
|
||||||
|
@ -442,9 +442,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
|
||||||
return m.ret1.(*x509.Certificate), m.err
|
return m.ret1.(*x509.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) {
|
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
|
||||||
if m.sign != nil {
|
if m.sign != nil {
|
||||||
return m.sign(cr, signOpts, extraOpts...)
|
return m.sign(cr, opts, signOpts...)
|
||||||
}
|
}
|
||||||
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
@ -456,11 +456,11 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.
|
||||||
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) ([]*authority.Provisioner, string, error) {
|
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
|
||||||
if m.getProvisioners != nil {
|
if m.getProvisioners != nil {
|
||||||
return m.getProvisioners(nextCursor, limit)
|
return m.getProvisioners(nextCursor, limit)
|
||||||
}
|
}
|
||||||
return m.ret1.([]*authority.Provisioner), m.ret2.(string), m.err
|
return m.ret1.(provisioner.List), m.ret2.(string), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
|
func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
|
||||||
|
@ -597,7 +597,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input string
|
input string
|
||||||
certAttrOpts []interface{}
|
certAttrOpts []provisioner.SignOption
|
||||||
autherr error
|
autherr error
|
||||||
cert *x509.Certificate
|
cert *x509.Certificate
|
||||||
root *x509.Certificate
|
root *x509.Certificate
|
||||||
|
@ -617,7 +617,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
h := New(&mockAuthority{
|
||||||
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||||||
authorize: func(ott string) ([]interface{}, error) {
|
authorize: func(ott string) ([]provisioner.SignOption, error) {
|
||||||
return tt.certAttrOpts, tt.autherr
|
return tt.certAttrOpts, tt.autherr
|
||||||
},
|
},
|
||||||
getTLSOptions: func() *tlsutil.TLSOptions {
|
getTLSOptions: func() *tlsutil.TLSOptions {
|
||||||
|
@ -723,14 +723,14 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p := []*authority.Provisioner{
|
p := provisioner.List{
|
||||||
{
|
&provisioner.JWK{
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Name: "max",
|
Name: "max",
|
||||||
EncryptedKey: "abc",
|
EncryptedKey: "abc",
|
||||||
Key: &key,
|
Key: &key,
|
||||||
},
|
},
|
||||||
{
|
&provisioner.JWK{
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Name: "mariano",
|
Name: "mariano",
|
||||||
EncryptedKey: "def",
|
EncryptedKey: "def",
|
||||||
|
|
|
@ -4,10 +4,10 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/cli/crypto/pemutil"
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
)
|
)
|
||||||
|
@ -23,11 +23,7 @@ type Authority struct {
|
||||||
certificates *sync.Map
|
certificates *sync.Map
|
||||||
ottMap *sync.Map
|
ottMap *sync.Map
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
provisionerIDIndex *sync.Map
|
provisioners *provisioner.Collection
|
||||||
encryptedKeyIndex *sync.Map
|
|
||||||
provisionerKeySetIndex *sync.Map
|
|
||||||
sortedProvisioners provisionerSlice
|
|
||||||
audiences []string
|
|
||||||
// Do not re-initialize
|
// Do not re-initialize
|
||||||
initOnce bool
|
initOnce bool
|
||||||
}
|
}
|
||||||
|
@ -39,31 +35,11 @@ func New(config *Config) (*Authority, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get sorted provisioners
|
|
||||||
var sorted provisionerSlice
|
|
||||||
if config.AuthorityConfig != nil {
|
|
||||||
sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define audiences: legacy + possible urls without the ports.
|
|
||||||
// The CA might have proxies in front so we cannot rely on the port.
|
|
||||||
audiences := []string{legacyAuthority}
|
|
||||||
for _, name := range config.DNSNames {
|
|
||||||
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
|
|
||||||
}
|
|
||||||
|
|
||||||
var a = &Authority{
|
var a = &Authority{
|
||||||
config: config,
|
config: config,
|
||||||
certificates: new(sync.Map),
|
certificates: new(sync.Map),
|
||||||
ottMap: new(sync.Map),
|
ottMap: new(sync.Map),
|
||||||
provisionerIDIndex: new(sync.Map),
|
provisioners: provisioner.NewCollection(config.getAudiences()),
|
||||||
encryptedKeyIndex: new(sync.Map),
|
|
||||||
provisionerKeySetIndex: new(sync.Map),
|
|
||||||
sortedProvisioners: sorted,
|
|
||||||
audiences: audiences,
|
|
||||||
}
|
}
|
||||||
if err := a.init(); err != nil {
|
if err := a.init(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -120,14 +96,15 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store all the provisioners
|
||||||
for _, p := range a.config.AuthorityConfig.Provisioners {
|
for _, p := range a.config.AuthorityConfig.Provisioners {
|
||||||
a.provisionerIDIndex.Store(p.ID(), p)
|
if err := a.provisioners.Store(p); err != nil {
|
||||||
if len(p.EncryptedKey) != 0 {
|
return err
|
||||||
a.encryptedKeyIndex.Store(p.Key.KeyID, p.EncryptedKey)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.startTime = time.Now()
|
// JWT numeric dates are seconds.
|
||||||
|
a.startTime = time.Now().Truncate(time.Second)
|
||||||
// Set flag indicating that initialization has been completed, and should
|
// Set flag indicating that initialization has been completed, and should
|
||||||
// not be repeated.
|
// not be repeated.
|
||||||
a.initOnce = true
|
a.initOnce = true
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
stepJOSE "github.com/smallstep/cli/jose"
|
stepJOSE "github.com/smallstep/cli/jose"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,22 +17,22 @@ func testAuthority(t *testing.T) *Authority {
|
||||||
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
disableRenewal := true
|
disableRenewal := true
|
||||||
p := []*Provisioner{
|
p := provisioner.List{
|
||||||
{
|
&provisioner.JWK{
|
||||||
Name: "Max",
|
Name: "Max",
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Key: maxjwk,
|
Key: maxjwk,
|
||||||
},
|
},
|
||||||
{
|
&provisioner.JWK{
|
||||||
Name: "step-cli",
|
Name: "step-cli",
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Key: clijwk,
|
Key: clijwk,
|
||||||
},
|
},
|
||||||
{
|
&provisioner.JWK{
|
||||||
Name: "dev",
|
Name: "dev",
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Key: maxjwk,
|
Key: maxjwk,
|
||||||
Claims: &ProvisionerClaims{
|
Claims: &provisioner.Claims{
|
||||||
DisableRenewal: &disableRenewal,
|
DisableRenewal: &disableRenewal,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -113,24 +114,18 @@ func TestAuthorityNew(t *testing.T) {
|
||||||
assert.True(t, auth.initOnce)
|
assert.True(t, auth.initOnce)
|
||||||
assert.NotNil(t, auth.intermediateIdentity)
|
assert.NotNil(t, auth.intermediateIdentity)
|
||||||
for _, p := range tc.config.AuthorityConfig.Provisioners {
|
for _, p := range tc.config.AuthorityConfig.Provisioners {
|
||||||
_p, ok := auth.provisionerIDIndex.Load(p.ID())
|
_p, ok := auth.provisioners.Load(p.GetID())
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Equals(t, p, _p)
|
assert.Equals(t, p, _p)
|
||||||
if len(p.EncryptedKey) > 0 {
|
if kid, encryptedKey, ok := p.GetEncryptedKey(); ok {
|
||||||
key, ok := auth.encryptedKeyIndex.Load(p.Key.KeyID)
|
key, ok := auth.provisioners.LoadEncryptedKey(kid)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Equals(t, p.EncryptedKey, key)
|
assert.Equals(t, encryptedKey, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// sanity check
|
// sanity check
|
||||||
_, ok = auth.provisionerIDIndex.Load("fooo")
|
_, ok = auth.provisioners.Load("fooo")
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
|
|
||||||
assert.Equals(t, auth.audiences, []string{
|
|
||||||
"step-certificate-authority",
|
|
||||||
"https://127.0.0.1/sign",
|
|
||||||
"https://127.0.0.1/1.0/sign",
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -2,14 +2,13 @@ package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/asn1"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"github.com/smallstep/cli/jose"
|
||||||
)
|
)
|
||||||
|
|
||||||
type idUsed struct {
|
type idUsed struct {
|
||||||
|
@ -17,49 +16,21 @@ type idUsed struct {
|
||||||
Subject string `json:"sub,omitempty"`
|
Subject string `json:"sub,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claims extends jwt.Claims with step attributes.
|
// Claims extends jose.Claims with step attributes.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
jwt.Claims
|
jose.Claims
|
||||||
SANs []string `json:"sans,omitempty"`
|
SANs []string `json:"sans,omitempty"`
|
||||||
}
|
Email string `json:"email,omitempty"`
|
||||||
|
Nonce string `json:"nonce,omitempty"`
|
||||||
// matchesAudience returns true if A and B share at least one element.
|
|
||||||
func matchesAudience(as, bs []string) bool {
|
|
||||||
if len(bs) == 0 || len(as) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, b := range bs {
|
|
||||||
for _, a := range as {
|
|
||||||
if b == a || stripPort(a) == stripPort(b) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// stripPort attempts to strip the port from the given url. If parsing the url
|
|
||||||
// produces errors it will just return the passed argument.
|
|
||||||
func stripPort(rawurl string) string {
|
|
||||||
u, err := url.Parse(rawurl)
|
|
||||||
if err != nil {
|
|
||||||
return rawurl
|
|
||||||
}
|
|
||||||
u.Host = u.Hostname()
|
|
||||||
return u.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authorize authorizes a signature request by validating and authenticating
|
// Authorize authorizes a signature request by validating and authenticating
|
||||||
// a OTT that must be sent w/ the request.
|
// a OTT that must be sent w/ the request.
|
||||||
func (a *Authority) Authorize(ott string) ([]interface{}, error) {
|
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
|
||||||
var (
|
var errContext = map[string]interface{}{"ott": ott}
|
||||||
errContext = map[string]interface{}{"ott": ott}
|
|
||||||
claims = Claims{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// Validate payload
|
// Validate payload
|
||||||
token, err := jwt.ParseSigned(ott)
|
token, err := jose.ParseSigned(ott)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &apiError{errors.Wrapf(err, "authorize: error parsing token"),
|
return nil, &apiError{errors.Wrapf(err, "authorize: error parsing token"),
|
||||||
http.StatusUnauthorized, errContext}
|
http.StatusUnauthorized, errContext}
|
||||||
|
@ -68,86 +39,52 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) {
|
||||||
// Get claims w/out verification. We need to look up the provisioner
|
// Get claims w/out verification. We need to look up the provisioner
|
||||||
// key in order to verify the claims and we need the issuer from the claims
|
// key in order to verify the claims and we need the issuer from the claims
|
||||||
// before we can look up the provisioner.
|
// before we can look up the provisioner.
|
||||||
|
var claims Claims
|
||||||
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
return nil, &apiError{err, http.StatusUnauthorized, errContext}
|
return nil, &apiError{err, http.StatusUnauthorized, errContext}
|
||||||
}
|
}
|
||||||
kid := token.Headers[0].KeyID // JWT will only have 1 header.
|
|
||||||
if len(kid) == 0 {
|
|
||||||
return nil, &apiError{errors.New("authorize: token KeyID cannot be empty"),
|
|
||||||
http.StatusUnauthorized, errContext}
|
|
||||||
}
|
|
||||||
pid := claims.Issuer + ":" + kid
|
|
||||||
val, ok := a.provisionerIDIndex.Load(pid)
|
|
||||||
if !ok {
|
|
||||||
return nil, &apiError{errors.Errorf("authorize: provisioner with id %s not found", pid),
|
|
||||||
http.StatusUnauthorized, errContext}
|
|
||||||
}
|
|
||||||
p, ok := val.(*Provisioner)
|
|
||||||
if !ok {
|
|
||||||
return nil, &apiError{errors.Errorf("authorize: invalid provisioner type"),
|
|
||||||
http.StatusInternalServerError, errContext}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = token.Claims(p.Key, &claims); err != nil {
|
|
||||||
return nil, &apiError{err, http.StatusUnauthorized, errContext}
|
|
||||||
}
|
|
||||||
|
|
||||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
|
||||||
// more than a few minutes.
|
|
||||||
if err = claims.ValidateWithLeeway(jwt.Expected{
|
|
||||||
Issuer: p.Name,
|
|
||||||
}, time.Minute); err != nil {
|
|
||||||
return nil, &apiError{errors.Wrapf(err, "authorize: invalid token"),
|
|
||||||
http.StatusUnauthorized, errContext}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do not accept tokens issued before the start of the ca.
|
// Do not accept tokens issued before the start of the ca.
|
||||||
// This check is meant as a stopgap solution to the current lack of a persistence layer.
|
// This check is meant as a stopgap solution to the current lack of a persistence layer.
|
||||||
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
|
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
|
||||||
if claims.IssuedAt > 0 && claims.IssuedAt.Time().Before(a.startTime) {
|
if claims.IssuedAt > 0 && claims.IssuedAt.Time().Before(a.startTime) {
|
||||||
return nil, &apiError{errors.New("token issued before the bootstrap of certificate authority"),
|
return nil, &apiError{errors.New("authorize: token issued before the bootstrap of certificate authority"),
|
||||||
http.StatusUnauthorized, errContext}
|
http.StatusUnauthorized, errContext}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !matchesAudience(claims.Audience, a.audiences) {
|
// This method will also validate the audiences for JWK provisioners.
|
||||||
return nil, &apiError{errors.New("authorize: token audience invalid"), http.StatusUnauthorized,
|
p, ok := a.provisioners.LoadByToken(token, &claims.Claims)
|
||||||
errContext}
|
if !ok {
|
||||||
}
|
return nil, &apiError{
|
||||||
|
errors.Errorf("authorize: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")),
|
||||||
if claims.Subject == "" {
|
|
||||||
return nil, &apiError{errors.New("authorize: token subject cannot be empty"),
|
|
||||||
http.StatusUnauthorized, errContext}
|
http.StatusUnauthorized, errContext}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: This is for backwards compatibility with older versions of cli
|
|
||||||
// and certificates. Older versions added the token subject as the only SAN
|
|
||||||
// in a CSR by default.
|
|
||||||
if len(claims.SANs) == 0 {
|
|
||||||
claims.SANs = []string{claims.Subject}
|
|
||||||
}
|
|
||||||
dnsNames, ips := x509util.SplitSANs(claims.SANs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
signOps := []interface{}{
|
|
||||||
&commonNameClaim{claims.Subject},
|
|
||||||
&dnsNamesClaim{dnsNames},
|
|
||||||
&ipAddressesClaim{ips},
|
|
||||||
p,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the token to protect against reuse.
|
// Store the token to protect against reuse.
|
||||||
if _, ok := a.ottMap.LoadOrStore(claims.ID, &idUsed{
|
var reuseKey string
|
||||||
|
switch p.GetType() {
|
||||||
|
case provisioner.TypeJWK:
|
||||||
|
reuseKey = claims.ID
|
||||||
|
case provisioner.TypeOIDC:
|
||||||
|
reuseKey = claims.Nonce
|
||||||
|
}
|
||||||
|
if reuseKey != "" {
|
||||||
|
if _, ok := a.ottMap.LoadOrStore(reuseKey, &idUsed{
|
||||||
UsedAt: time.Now().Unix(),
|
UsedAt: time.Now().Unix(),
|
||||||
Subject: claims.Subject,
|
Subject: claims.Subject,
|
||||||
}); ok {
|
}); ok {
|
||||||
return nil, &apiError{errors.Errorf("token already used"), http.StatusUnauthorized,
|
return nil, &apiError{errors.Errorf("authorize: token already used"), http.StatusUnauthorized, errContext}
|
||||||
errContext}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return signOps, nil
|
// Call the provisioner Authorize method to get the signing options
|
||||||
|
opts, err := p.Authorize(ott)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &apiError{errors.Wrap(err, "authorize"), http.StatusUnauthorized, errContext}
|
||||||
|
}
|
||||||
|
|
||||||
|
return opts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// authorizeRenewal tries to locate the step provisioner extension, and checks
|
// authorizeRenewal tries to locate the step provisioner extension, and checks
|
||||||
|
@ -157,46 +94,20 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) {
|
||||||
// TODO(mariano): should we authorize by default?
|
// TODO(mariano): should we authorize by default?
|
||||||
func (a *Authority) authorizeRenewal(crt *x509.Certificate) error {
|
func (a *Authority) authorizeRenewal(crt *x509.Certificate) error {
|
||||||
errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()}
|
errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()}
|
||||||
for _, e := range crt.Extensions {
|
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||||
if e.Id.Equal(stepOIDProvisioner) {
|
|
||||||
var provisioner stepProvisionerASN1
|
|
||||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
|
||||||
return &apiError{
|
|
||||||
err: errors.Wrap(err, "error decoding step provisioner extension"),
|
|
||||||
code: http.StatusInternalServerError,
|
|
||||||
context: errContext,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look for the provisioner, if it cannot be found, renewal will not
|
|
||||||
// be authorized.
|
|
||||||
pid := string(provisioner.Name) + ":" + string(provisioner.CredentialID)
|
|
||||||
val, ok := a.provisionerIDIndex.Load(pid)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return &apiError{
|
return &apiError{
|
||||||
err: errors.Errorf("not found: provisioner %s", pid),
|
err: errors.New("provisioner not found"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
context: errContext,
|
context: errContext,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p, ok := val.(*Provisioner)
|
if err := p.AuthorizeRenewal(crt); err != nil {
|
||||||
if !ok {
|
|
||||||
return &apiError{
|
return &apiError{
|
||||||
err: errors.Errorf("invalid type: provisioner %s, type %T", pid, val),
|
err: err,
|
||||||
code: http.StatusInternalServerError,
|
|
||||||
context: errContext,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if p.Claims.IsDisableRenewal() {
|
|
||||||
return &apiError{
|
|
||||||
err: errors.Errorf("renew disabled: provisioner %s", pid),
|
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
context: errContext,
|
context: errContext,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,100 +7,52 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/cli/crypto/keys"
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
stepJOSE "github.com/smallstep/cli/jose"
|
"github.com/smallstep/cli/jose"
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMatchesAudience(t *testing.T) {
|
func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||||
type matchesTest struct {
|
sig, err := jose.NewSigner(
|
||||||
a, b []string
|
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||||
exp bool
|
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||||
}
|
)
|
||||||
tests := map[string]matchesTest{
|
if err != nil {
|
||||||
"false arg1 empty": {
|
return "", err
|
||||||
a: []string{},
|
|
||||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
||||||
exp: false,
|
|
||||||
},
|
|
||||||
"false arg2 empty": {
|
|
||||||
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
||||||
b: []string{},
|
|
||||||
exp: false,
|
|
||||||
},
|
|
||||||
"false arg1,arg2 empty": {
|
|
||||||
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
||||||
b: []string{"step-gateway", "step-cli"},
|
|
||||||
exp: false,
|
|
||||||
},
|
|
||||||
"false": {
|
|
||||||
a: []string{"step-gateway", "step-cli"},
|
|
||||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
||||||
exp: false,
|
|
||||||
},
|
|
||||||
"true": {
|
|
||||||
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
|
||||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
||||||
exp: true,
|
|
||||||
},
|
|
||||||
"true,portsA": {
|
|
||||||
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
|
||||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
|
||||||
exp: true,
|
|
||||||
},
|
|
||||||
"true,portsB": {
|
|
||||||
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
|
||||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"},
|
|
||||||
exp: true,
|
|
||||||
},
|
|
||||||
"true,portsAB": {
|
|
||||||
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
|
||||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"},
|
|
||||||
exp: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStripPort(t *testing.T) {
|
id, err := randutil.ASCII(64)
|
||||||
type args struct {
|
if err != nil {
|
||||||
rawurl string
|
return "", err
|
||||||
}
|
}
|
||||||
tests := []struct {
|
|
||||||
name string
|
claims := struct {
|
||||||
args args
|
jose.Claims
|
||||||
want string
|
SANS []string `json:"sans"`
|
||||||
}{
|
}{
|
||||||
{"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"},
|
Claims: jose.Claims{
|
||||||
{"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"},
|
ID: id,
|
||||||
{"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"},
|
Subject: sub,
|
||||||
}
|
Issuer: iss,
|
||||||
for _, tt := range tests {
|
IssuedAt: jose.NewNumericDate(iat),
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
NotBefore: jose.NewNumericDate(iat),
|
||||||
if got := stripPort(tt.args.rawurl); got != tt.want {
|
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||||
t.Errorf("stripPort() = %v, want %v", got, tt.want)
|
Audience: []string{aud},
|
||||||
}
|
},
|
||||||
})
|
SANS: sans,
|
||||||
}
|
}
|
||||||
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthorize(t *testing.T) {
|
func TestAuthorize(t *testing.T) {
|
||||||
a := testAuthority(t)
|
a := testAuthority(t)
|
||||||
jwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_priv.jwk",
|
|
||||||
stepJOSE.WithPassword([]byte("pass")))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
// Invalid keys
|
||||||
|
keyNoKid := &jose.JSONWebKey{Key: key.Key, KeyID: ""}
|
||||||
|
keyBadKid := &jose.JSONWebKey{Key: key.Key, KeyID: "foo"}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
validIssuer := "step-cli"
|
validIssuer := "step-cli"
|
||||||
validAudience := []string{"https://test.ca.smallstep.com/sign"}
|
validAudience := []string{"https://test.ca.smallstep.com/sign"}
|
||||||
|
|
||||||
|
@ -120,100 +72,37 @@ func TestAuthorize(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail empty key id": func(t *testing.T) *authorizeTest {
|
"fail empty key id": func(t *testing.T) *authorizeTest {
|
||||||
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyNoKid)
|
||||||
(&jose.SignerOptions{}).WithType("JWT"))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
cl := jwt.Claims{
|
|
||||||
Subject: "test.smallstep.com",
|
|
||||||
Issuer: validIssuer,
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
||||||
Audience: validAudience,
|
|
||||||
ID: "43",
|
|
||||||
}
|
|
||||||
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorize: token KeyID cannot be empty"),
|
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, context{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail provisioner not found": func(t *testing.T) *authorizeTest {
|
"fail provisioner not found": func(t *testing.T) *authorizeTest {
|
||||||
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyBadKid)
|
||||||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo"))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
cl := jwt.Claims{
|
|
||||||
Subject: "test.smallstep.com",
|
|
||||||
Issuer: validIssuer,
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
||||||
Audience: validAudience,
|
|
||||||
ID: "43",
|
|
||||||
}
|
|
||||||
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorize: provisioner with id step-cli:foo not found"),
|
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, context{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail invalid provisioner": func(t *testing.T) *authorizeTest {
|
|
||||||
_a := testAuthority(t)
|
|
||||||
|
|
||||||
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
|
||||||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo"))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
_a.provisionerIDIndex.Store(validIssuer+":foo", "42")
|
|
||||||
|
|
||||||
cl := jwt.Claims{
|
|
||||||
Subject: "test.smallstep.com",
|
|
||||||
Issuer: validIssuer,
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
||||||
Audience: validAudience,
|
|
||||||
ID: "43",
|
|
||||||
}
|
|
||||||
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return &authorizeTest{
|
|
||||||
auth: _a,
|
|
||||||
ott: raw,
|
|
||||||
err: &apiError{errors.New("authorize: invalid provisioner type"),
|
|
||||||
http.StatusInternalServerError, context{"ott": raw}},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail invalid issuer": func(t *testing.T) *authorizeTest {
|
"fail invalid issuer": func(t *testing.T) *authorizeTest {
|
||||||
cl := jwt.Claims{
|
raw, err := generateToken("test.smallstep.com", "invalid-issuer", validAudience[0], nil, now, key)
|
||||||
Subject: "subject",
|
|
||||||
Issuer: "invalid-issuer",
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
||||||
Audience: validAudience,
|
|
||||||
}
|
|
||||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("authorize: provisioner with id invalid-issuer:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc not found"),
|
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, context{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail empty subject": func(t *testing.T) *authorizeTest {
|
"fail empty subject": func(t *testing.T) *authorizeTest {
|
||||||
cl := jwt.Claims{
|
raw, err := generateToken("", validIssuer, validAudience[0], nil, now, key)
|
||||||
Subject: "",
|
|
||||||
Issuer: validIssuer,
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
||||||
Audience: validAudience,
|
|
||||||
}
|
|
||||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
|
@ -223,64 +112,34 @@ func TestAuthorize(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail verify-sig-failure": func(t *testing.T) *authorizeTest {
|
"fail verify-sig-failure": func(t *testing.T) *authorizeTest {
|
||||||
_, priv2, err := keys.GenerateDefaultKeyPair()
|
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
|
||||||
assert.FatalError(t, err)
|
|
||||||
invalidKeySig, err := jose.NewSigner(jose.SigningKey{
|
|
||||||
Algorithm: jose.ES256,
|
|
||||||
Key: priv2,
|
|
||||||
}, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
cl := jwt.Claims{
|
|
||||||
Subject: "test.smallstep.com",
|
|
||||||
Issuer: validIssuer,
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
||||||
Audience: validAudience,
|
|
||||||
}
|
|
||||||
raw, err := jwt.Signed(invalidKeySig).Claims(cl).CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw + "00",
|
||||||
err: &apiError{errors.New("square/go-jose: error in cryptographic primitive"),
|
err: &apiError{errors.New("authorize: error parsing claims: square/go-jose: error in cryptographic primitive"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, context{"ott": raw + "00"}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail token-already-used": func(t *testing.T) *authorizeTest {
|
"fail token-already-used": func(t *testing.T) *authorizeTest {
|
||||||
cl := jwt.Claims{
|
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
|
||||||
Subject: "test.smallstep.com",
|
|
||||||
Issuer: validIssuer,
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
||||||
Audience: validAudience,
|
|
||||||
ID: "42",
|
|
||||||
}
|
|
||||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
_, err = a.Authorize(raw)
|
_, err = a.Authorize(raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
err: &apiError{errors.New("token already used"),
|
err: &apiError{errors.New("authorize: token already used"),
|
||||||
http.StatusUnauthorized, context{"ott": raw}},
|
http.StatusUnauthorized, context{"ott": raw}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) *authorizeTest {
|
"ok": func(t *testing.T) *authorizeTest {
|
||||||
cl := jwt.Claims{
|
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
|
||||||
Subject: "test.smallstep.com",
|
|
||||||
Issuer: validIssuer,
|
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
|
||||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
|
||||||
Audience: validAudience,
|
|
||||||
ID: "43",
|
|
||||||
}
|
|
||||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
ott: raw,
|
ott: raw,
|
||||||
res: []interface{}{"1", "2", "3", "4"},
|
res: []interface{}{"1", "2", "3", "4", "5", "6"},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,117 +0,0 @@
|
||||||
package authority
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
x509 "github.com/smallstep/cli/pkg/x509"
|
|
||||||
)
|
|
||||||
|
|
||||||
// certClaim interface is implemented by types used to validate specific claims in a
|
|
||||||
// certificate request.
|
|
||||||
type certClaim interface {
|
|
||||||
Valid(crt *x509.Certificate) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateClaims returns nil if all the claims are validated, it will return
|
|
||||||
// the first error if a claim fails.
|
|
||||||
func validateClaims(crt *x509.Certificate, claims []certClaim) (err error) {
|
|
||||||
for _, c := range claims {
|
|
||||||
if err = c.Valid(crt); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// commonNameClaim validates the common name of a certificate request.
|
|
||||||
type commonNameClaim struct {
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Valid checks that certificate request common name matches the one configured.
|
|
||||||
func (c *commonNameClaim) Valid(crt *x509.Certificate) error {
|
|
||||||
if crt.Subject.CommonName == "" {
|
|
||||||
return errors.New("common name cannot be empty")
|
|
||||||
}
|
|
||||||
if crt.Subject.CommonName != c.name {
|
|
||||||
return errors.Errorf("common name claim failed - got %s, want %s", crt.Subject.CommonName, c.name)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type dnsNamesClaim struct {
|
|
||||||
names []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Valid checks that certificate request DNS Names match those configured in
|
|
||||||
// the bootstrap (token) flow.
|
|
||||||
func (c *dnsNamesClaim) Valid(crt *x509.Certificate) error {
|
|
||||||
tokMap := make(map[string]int)
|
|
||||||
for _, e := range c.names {
|
|
||||||
tokMap[e] = 1
|
|
||||||
}
|
|
||||||
crtMap := make(map[string]int)
|
|
||||||
for _, e := range crt.DNSNames {
|
|
||||||
crtMap[e] = 1
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(tokMap, crtMap) {
|
|
||||||
return errors.Errorf("DNS names claim failed - got %s, want %s", crt.DNSNames, c.names)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ipAddressesClaim struct {
|
|
||||||
ips []net.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
// Valid checks that certificate request IP Addresses match those configured in
|
|
||||||
// the bootstrap (token) flow.
|
|
||||||
func (c *ipAddressesClaim) Valid(crt *x509.Certificate) error {
|
|
||||||
tokMap := make(map[string]int)
|
|
||||||
for _, e := range c.ips {
|
|
||||||
tokMap[e.String()] = 1
|
|
||||||
}
|
|
||||||
crtMap := make(map[string]int)
|
|
||||||
for _, e := range crt.IPAddresses {
|
|
||||||
crtMap[e.String()] = 1
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(tokMap, crtMap) {
|
|
||||||
return errors.Errorf("IP Addresses claim failed - got %v, want %v", crt.IPAddresses, c.ips)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// certTemporalClaim validates the certificate temporal validity settings.
|
|
||||||
type certTemporalClaim struct {
|
|
||||||
min time.Duration
|
|
||||||
max time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate validates the certificate temporal validity settings.
|
|
||||||
func (ctc *certTemporalClaim) Valid(crt *x509.Certificate) error {
|
|
||||||
var (
|
|
||||||
na = crt.NotAfter
|
|
||||||
nb = crt.NotBefore
|
|
||||||
d = na.Sub(nb)
|
|
||||||
now = time.Now()
|
|
||||||
)
|
|
||||||
|
|
||||||
if na.Before(now) {
|
|
||||||
return errors.Errorf("NotAfter: %v cannot be in the past", na)
|
|
||||||
}
|
|
||||||
if na.Before(nb) {
|
|
||||||
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
|
|
||||||
}
|
|
||||||
if d < ctc.min {
|
|
||||||
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
|
|
||||||
d, ctc.min)
|
|
||||||
}
|
|
||||||
if d > ctc.max {
|
|
||||||
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
|
|
||||||
d, ctc.max)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,132 +0,0 @@
|
||||||
package authority
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
x509 "github.com/smallstep/cli/pkg/x509"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCommonNameClaim_Valid(t *testing.T) {
|
|
||||||
tests := map[string]struct {
|
|
||||||
cnc certClaim
|
|
||||||
crt *x509.Certificate
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"empty-common-name": {
|
|
||||||
cnc: &commonNameClaim{name: "foo"},
|
|
||||||
crt: &x509.Certificate{},
|
|
||||||
err: errors.New("common name cannot be empty"),
|
|
||||||
},
|
|
||||||
"wrong-common-name": {
|
|
||||||
cnc: &commonNameClaim{name: "foo"},
|
|
||||||
crt: &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}},
|
|
||||||
err: errors.New("common name claim failed - got bar, want foo"),
|
|
||||||
},
|
|
||||||
"ok": {
|
|
||||||
cnc: &commonNameClaim{name: "foo"},
|
|
||||||
crt: &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
err := tc.cnc.Valid(tc.crt)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.err.Error(), err.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, tc.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIPAddressesClaim_Valid(t *testing.T) {
|
|
||||||
tests := map[string]struct {
|
|
||||||
iac certClaim
|
|
||||||
crt *x509.Certificate
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"unexpected-ip-in-crt": {
|
|
||||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
|
||||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
|
|
||||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1 1.1.1.1], want [127.0.0.1]"),
|
|
||||||
},
|
|
||||||
"missing-ip-in-crt": {
|
|
||||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
|
|
||||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
|
||||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1], want [127.0.0.1 1.1.1.1]"),
|
|
||||||
},
|
|
||||||
"invalid-matcher-nonempty-ips": {
|
|
||||||
iac: &ipAddressesClaim{ips: []net.IP{}},
|
|
||||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
|
||||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1], want []"),
|
|
||||||
},
|
|
||||||
"ok": {
|
|
||||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
|
||||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
|
||||||
},
|
|
||||||
"ok-multiple-identical-ip-entries": {
|
|
||||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
|
||||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1")}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
err := tc.iac.Valid(tc.crt)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.err.Error(), err.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, tc.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSNamesClaim_Valid(t *testing.T) {
|
|
||||||
tests := map[string]struct {
|
|
||||||
dnc certClaim
|
|
||||||
crt *x509.Certificate
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"unexpected-dns-name-in-crt": {
|
|
||||||
dnc: &dnsNamesClaim{names: []string{"foo"}},
|
|
||||||
crt: &x509.Certificate{DNSNames: []string{"foo", "bar"}},
|
|
||||||
err: errors.New("DNS names claim failed - got [foo bar], want [foo]"),
|
|
||||||
},
|
|
||||||
"ok": {
|
|
||||||
dnc: &dnsNamesClaim{names: []string{"foo", "bar"}},
|
|
||||||
crt: &x509.Certificate{DNSNames: []string{"bar", "foo"}},
|
|
||||||
},
|
|
||||||
"missing-dns-name-in-crt": {
|
|
||||||
dnc: &dnsNamesClaim{names: []string{"foo", "bar"}},
|
|
||||||
crt: &x509.Certificate{DNSNames: []string{"foo"}},
|
|
||||||
err: errors.New("DNS names claim failed - got [foo], want [foo bar]"),
|
|
||||||
},
|
|
||||||
"ok-multiple-identical-dns-entries": {
|
|
||||||
dnc: &dnsNamesClaim{names: []string{"foo"}},
|
|
||||||
crt: &x509.Certificate{DNSNames: []string{"foo", "foo", "foo"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
err := tc.dnc.Valid(tc.crt)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.err.Error(), err.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, tc.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -2,11 +2,13 @@ package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/cli/crypto/tlsutil"
|
"github.com/smallstep/cli/crypto/tlsutil"
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
)
|
)
|
||||||
|
@ -25,10 +27,10 @@ var (
|
||||||
Renegotiation: false,
|
Renegotiation: false,
|
||||||
}
|
}
|
||||||
defaultDisableRenewal = false
|
defaultDisableRenewal = false
|
||||||
globalProvisionerClaims = ProvisionerClaims{
|
globalProvisionerClaims = provisioner.Claims{
|
||||||
MinTLSDur: &Duration{5 * time.Minute},
|
MinTLSDur: &provisioner.Duration{5 * time.Minute},
|
||||||
MaxTLSDur: &Duration{24 * time.Hour},
|
MaxTLSDur: &provisioner.Duration{24 * time.Hour},
|
||||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
DefaultTLSDur: &provisioner.Duration{24 * time.Hour},
|
||||||
DisableRenewal: &defaultDisableRenewal,
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -50,16 +52,15 @@ type Config struct {
|
||||||
|
|
||||||
// AuthConfig represents the configuration options for the authority.
|
// AuthConfig represents the configuration options for the authority.
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
Provisioners []*Provisioner `json:"provisioners,omitempty"`
|
Provisioners provisioner.List `json:"provisioners"`
|
||||||
Template *x509util.ASN1DN `json:"template,omitempty"`
|
Template *x509util.ASN1DN `json:"template,omitempty"`
|
||||||
Claims *ProvisionerClaims `json:"claims,omitempty"`
|
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates the authority configuration.
|
// Validate validates the authority configuration.
|
||||||
func (c *AuthConfig) Validate() error {
|
func (c *AuthConfig) Validate(audiences []string) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return errors.New("authority cannot be undefined")
|
return errors.New("authority cannot be undefined")
|
||||||
}
|
}
|
||||||
|
@ -70,11 +71,18 @@ func (c *AuthConfig) Validate() error {
|
||||||
if c.Claims, err = c.Claims.Init(&globalProvisionerClaims); err != nil {
|
if c.Claims, err = c.Claims.Init(&globalProvisionerClaims); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize provisioners
|
||||||
|
config := provisioner.Config{
|
||||||
|
Claims: *c.Claims,
|
||||||
|
Audiences: audiences,
|
||||||
|
}
|
||||||
for _, p := range c.Provisioners {
|
for _, p := range c.Provisioners {
|
||||||
if err := p.Init(c.Claims); err != nil {
|
if err := p.Init(config); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Template == nil {
|
if c.Template == nil {
|
||||||
c.Template = &x509util.ASN1DN{}
|
c.Template = &x509util.ASN1DN{}
|
||||||
}
|
}
|
||||||
|
@ -153,5 +161,16 @@ func (c *Config) Validate() error {
|
||||||
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
|
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.AuthorityConfig.Validate()
|
return c.AuthorityConfig.Validate(c.getAudiences())
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAudiences returns the legacy and possible urls without the ports that will
|
||||||
|
// be used as the default provisioner audiences. The CA might have proxies in
|
||||||
|
// front so we cannot rely on the port.
|
||||||
|
func (c *Config) getAudiences() []string {
|
||||||
|
audiences := []string{legacyAuthority}
|
||||||
|
for _, name := range c.DNSNames {
|
||||||
|
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
|
||||||
|
}
|
||||||
|
return audiences
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/cli/crypto/tlsutil"
|
"github.com/smallstep/cli/crypto/tlsutil"
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
stepJOSE "github.com/smallstep/cli/jose"
|
stepJOSE "github.com/smallstep/cli/jose"
|
||||||
|
@ -17,13 +18,13 @@ func TestConfigValidate(t *testing.T) {
|
||||||
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ac := &AuthConfig{
|
ac := &AuthConfig{
|
||||||
Provisioners: []*Provisioner{
|
Provisioners: provisioner.List{
|
||||||
{
|
&provisioner.JWK{
|
||||||
Name: "Max",
|
Name: "Max",
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Key: maxjwk,
|
Key: maxjwk,
|
||||||
},
|
},
|
||||||
{
|
&provisioner.JWK{
|
||||||
Name: "step-cli",
|
Name: "step-cli",
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Key: clijwk,
|
Key: clijwk,
|
||||||
|
@ -229,13 +230,13 @@ func TestAuthConfigValidate(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p := []*Provisioner{
|
p := provisioner.List{
|
||||||
{
|
&provisioner.JWK{
|
||||||
Name: "Max",
|
Name: "Max",
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Key: maxjwk,
|
Key: maxjwk,
|
||||||
},
|
},
|
||||||
{
|
&provisioner.JWK{
|
||||||
Name: "step-cli",
|
Name: "step-cli",
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Key: clijwk,
|
Key: clijwk,
|
||||||
|
@ -263,9 +264,9 @@ func TestAuthConfigValidate(t *testing.T) {
|
||||||
"fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest {
|
"fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest {
|
||||||
return AuthConfigValidateTest{
|
return AuthConfigValidateTest{
|
||||||
ac: &AuthConfig{
|
ac: &AuthConfig{
|
||||||
Provisioners: []*Provisioner{
|
Provisioners: provisioner.List{
|
||||||
{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
&provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
||||||
{Name: "foo", Key: &jose.JSONWebKey{}},
|
&provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
err: errors.New("provisioner type cannot be empty"),
|
err: errors.New("provisioner type cannot be empty"),
|
||||||
|
@ -293,7 +294,7 @@ func TestAuthConfigValidate(t *testing.T) {
|
||||||
for name, get := range tests {
|
for name, get := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := get(t)
|
tc := get(t)
|
||||||
err := tc.ac.Validate()
|
err := tc.ac.Validate([]string{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.Equals(t, tc.err.Error(), err.Error())
|
assert.Equals(t, tc.err.Error(), err.Error())
|
||||||
|
|
|
@ -1,17 +1,14 @@
|
||||||
package authority
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
|
||||||
|
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProvisionerClaims so that individual provisioners can override global claims.
|
// Claims so that individual provisioners can override global claims.
|
||||||
type ProvisionerClaims struct {
|
type Claims struct {
|
||||||
globalClaims *ProvisionerClaims
|
globalClaims *Claims
|
||||||
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||||
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||||
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||||
|
@ -19,19 +16,18 @@ type ProvisionerClaims struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes and validates the individual provisioner claims.
|
// Init initializes and validates the individual provisioner claims.
|
||||||
func (pc *ProvisionerClaims) Init(global *ProvisionerClaims) (*ProvisionerClaims, error) {
|
func (pc *Claims) Init(global *Claims) (*Claims, error) {
|
||||||
if pc == nil {
|
if pc == nil {
|
||||||
pc = &ProvisionerClaims{}
|
pc = &Claims{}
|
||||||
}
|
}
|
||||||
pc.globalClaims = global
|
pc.globalClaims = global
|
||||||
err := pc.Validate()
|
return pc, pc.Validate()
|
||||||
return pc, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultTLSCertDuration returns the default TLS cert duration for the
|
// DefaultTLSCertDuration returns the default TLS cert duration for the
|
||||||
// provisioner. If the default is not set within the provisioner, then the global
|
// provisioner. If the default is not set within the provisioner, then the global
|
||||||
// default from the authority configuration will be used.
|
// default from the authority configuration will be used.
|
||||||
func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration {
|
func (pc *Claims) DefaultTLSCertDuration() time.Duration {
|
||||||
if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 {
|
if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 {
|
||||||
return pc.globalClaims.DefaultTLSCertDuration()
|
return pc.globalClaims.DefaultTLSCertDuration()
|
||||||
}
|
}
|
||||||
|
@ -41,7 +37,7 @@ func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration {
|
||||||
// MinTLSCertDuration returns the minimum TLS cert duration for the provisioner.
|
// MinTLSCertDuration returns the minimum TLS cert duration for the provisioner.
|
||||||
// If the minimum is not set within the provisioner, then the global
|
// If the minimum is not set within the provisioner, then the global
|
||||||
// minimum from the authority configuration will be used.
|
// minimum from the authority configuration will be used.
|
||||||
func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration {
|
func (pc *Claims) MinTLSCertDuration() time.Duration {
|
||||||
if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 {
|
if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 {
|
||||||
return pc.globalClaims.MinTLSCertDuration()
|
return pc.globalClaims.MinTLSCertDuration()
|
||||||
}
|
}
|
||||||
|
@ -51,7 +47,7 @@ func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration {
|
||||||
// MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner.
|
// MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner.
|
||||||
// If the maximum is not set within the provisioner, then the global
|
// If the maximum is not set within the provisioner, then the global
|
||||||
// maximum from the authority configuration will be used.
|
// maximum from the authority configuration will be used.
|
||||||
func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration {
|
func (pc *Claims) MaxTLSCertDuration() time.Duration {
|
||||||
if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 {
|
if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 {
|
||||||
return pc.globalClaims.MaxTLSCertDuration()
|
return pc.globalClaims.MaxTLSCertDuration()
|
||||||
}
|
}
|
||||||
|
@ -61,7 +57,7 @@ func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration {
|
||||||
// IsDisableRenewal returns if the renewal flow is disabled for the
|
// IsDisableRenewal returns if the renewal flow is disabled for the
|
||||||
// provisioner. If the property is not set within the provisioner, then the
|
// provisioner. If the property is not set within the provisioner, then the
|
||||||
// global value from the authority configuration will be used.
|
// global value from the authority configuration will be used.
|
||||||
func (pc *ProvisionerClaims) IsDisableRenewal() bool {
|
func (pc *Claims) IsDisableRenewal() bool {
|
||||||
if pc.DisableRenewal == nil {
|
if pc.DisableRenewal == nil {
|
||||||
return pc.globalClaims.IsDisableRenewal()
|
return pc.globalClaims.IsDisableRenewal()
|
||||||
}
|
}
|
||||||
|
@ -69,7 +65,7 @@ func (pc *ProvisionerClaims) IsDisableRenewal() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates and modifies the Claims with default values.
|
// Validate validates and modifies the Claims with default values.
|
||||||
func (pc *ProvisionerClaims) Validate() error {
|
func (pc *Claims) Validate() error {
|
||||||
var (
|
var (
|
||||||
min = pc.MinTLSCertDuration()
|
min = pc.MinTLSCertDuration()
|
||||||
max = pc.MaxTLSCertDuration()
|
max = pc.MaxTLSCertDuration()
|
||||||
|
@ -93,52 +89,3 @@ func (pc *ProvisionerClaims) Validate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Provisioner - authorized entity that can sign tokens necessary for signature requests.
|
|
||||||
type Provisioner struct {
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
Key *jose.JSONWebKey `json:"key,omitempty"`
|
|
||||||
EncryptedKey string `json:"encryptedKey,omitempty"`
|
|
||||||
Claims *ProvisionerClaims `json:"claims,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Init initializes and validates a the fields of Provisioner type.
|
|
||||||
func (p *Provisioner) Init(global *ProvisionerClaims) error {
|
|
||||||
switch {
|
|
||||||
case p.Name == "":
|
|
||||||
return errors.New("provisioner name cannot be empty")
|
|
||||||
|
|
||||||
case p.Type == "":
|
|
||||||
return errors.New("provisioner type cannot be empty")
|
|
||||||
|
|
||||||
case p.Key == nil:
|
|
||||||
return errors.New("provisioner key cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
p.Claims, err = p.Claims.Init(global)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTLSApps returns a list of modifiers and validators that will be applied to
|
|
||||||
// the certificate.
|
|
||||||
func (p *Provisioner) getTLSApps(so SignOptions) ([]x509util.WithOption, []certClaim, error) {
|
|
||||||
c := p.Claims
|
|
||||||
return []x509util.WithOption{
|
|
||||||
x509util.WithNotBeforeAfterDuration(so.NotBefore,
|
|
||||||
so.NotAfter, c.DefaultTLSCertDuration()),
|
|
||||||
withProvisionerOID(p.Name, p.Key.KeyID),
|
|
||||||
}, []certClaim{
|
|
||||||
&certTemporalClaim{
|
|
||||||
min: c.MinTLSCertDuration(),
|
|
||||||
max: c.MaxTLSCertDuration(),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ID returns the provisioner identifier. The name and credential id should
|
|
||||||
// uniquely identify any provisioner.
|
|
||||||
func (p *Provisioner) ID() string {
|
|
||||||
return p.Name + ":" + p.Key.KeyID
|
|
||||||
}
|
|
212
authority/provisioner/collection.go
Normal file
212
authority/provisioner/collection.go
Normal file
|
@ -0,0 +1,212 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/asn1"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||||
|
const DefaultProvisionersLimit = 20
|
||||||
|
|
||||||
|
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
||||||
|
const DefaultProvisionersMax = 100
|
||||||
|
|
||||||
|
type uidProvisioner struct {
|
||||||
|
provisioner Interface
|
||||||
|
uid string
|
||||||
|
}
|
||||||
|
|
||||||
|
type provisionerSlice []uidProvisioner
|
||||||
|
|
||||||
|
func (p provisionerSlice) Len() int { return len(p) }
|
||||||
|
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||||
|
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||||
|
|
||||||
|
// Collection is a memory map of provisioners.
|
||||||
|
type Collection struct {
|
||||||
|
byID *sync.Map
|
||||||
|
byKey *sync.Map
|
||||||
|
sorted provisionerSlice
|
||||||
|
audiences []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCollection initializes a collection of provisioners. The given list of
|
||||||
|
// audiences are the audiences used by the JWT provisioner.
|
||||||
|
func NewCollection(audiences []string) *Collection {
|
||||||
|
return &Collection{
|
||||||
|
byID: new(sync.Map),
|
||||||
|
byKey: new(sync.Map),
|
||||||
|
audiences: audiences,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load a provisioner by the ID.
|
||||||
|
func (c *Collection) Load(id string) (Interface, bool) {
|
||||||
|
return loadProvisioner(c.byID, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadByToken parses the token claims and loads the provisioner associated.
|
||||||
|
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
|
||||||
|
// match with server audiences
|
||||||
|
if matchesAudience(claims.Audience, c.audiences) {
|
||||||
|
// If matches with stored audiences it will be a JWT token (default), and
|
||||||
|
// the id would be <issuer>:<kid>.
|
||||||
|
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ID will be just the clientID stored in azp or aud.
|
||||||
|
var payload openIDPayload
|
||||||
|
if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// audience is required
|
||||||
|
if len(payload.Audience) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if len(payload.AuthorizedParty) > 0 {
|
||||||
|
return c.Load(payload.AuthorizedParty)
|
||||||
|
}
|
||||||
|
return c.Load(payload.Audience[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadByCertificate looks for the provisioner extension and extracts the
|
||||||
|
// proper id to load the provisioner.
|
||||||
|
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
|
||||||
|
for _, e := range cert.Extensions {
|
||||||
|
if e.Id.Equal(stepOIDProvisioner) {
|
||||||
|
var provisioner stepProvisionerASN1
|
||||||
|
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if provisioner.Type == int(TypeJWK) {
|
||||||
|
return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID))
|
||||||
|
}
|
||||||
|
return c.Load(string(provisioner.CredentialID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to noop provisioner if an extension is not found. This allows to
|
||||||
|
// accept a renewal of a cert without the provisioner extension.
|
||||||
|
return &noop{}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadEncryptedKey returns an encrypted key by indexed by KeyID. At this moment
|
||||||
|
// only JWK encrypted keys are indexed by KeyID.
|
||||||
|
func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
||||||
|
p, ok := loadProvisioner(c.byKey, keyID)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
_, key, ok := p.GetEncryptedKey()
|
||||||
|
return key, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store adds a provisioner to the collection and enforces the uniqueness of
|
||||||
|
// provisioner IDs.
|
||||||
|
func (c *Collection) Store(p Interface) error {
|
||||||
|
// Store provisioner always in byID. ID must be unique.
|
||||||
|
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true {
|
||||||
|
return errors.New("cannot add multiple provisioners with the same id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store provisioner in byKey if EncryptedKey is defined.
|
||||||
|
if kid, _, ok := p.GetEncryptedKey(); ok {
|
||||||
|
c.byKey.Store(kid, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store sorted provisioners.
|
||||||
|
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||||
|
// Using big endian format to get the strings sorted:
|
||||||
|
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||||
|
bi := make([]byte, 4)
|
||||||
|
sum := provisionerSum(p)
|
||||||
|
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
|
||||||
|
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||||
|
c.sorted = append(c.sorted, uidProvisioner{
|
||||||
|
provisioner: p,
|
||||||
|
uid: hex.EncodeToString(sum),
|
||||||
|
})
|
||||||
|
sort.Sort(c.sorted)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find implements pagination on a list of sorted provisioners.
|
||||||
|
func (c *Collection) Find(cursor string, limit int) (List, string) {
|
||||||
|
switch {
|
||||||
|
case limit <= 0:
|
||||||
|
limit = DefaultProvisionersLimit
|
||||||
|
case limit > DefaultProvisionersMax:
|
||||||
|
limit = DefaultProvisionersMax
|
||||||
|
}
|
||||||
|
|
||||||
|
n := c.sorted.Len()
|
||||||
|
cursor = fmt.Sprintf("%040s", cursor)
|
||||||
|
i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })
|
||||||
|
|
||||||
|
slice := List{}
|
||||||
|
for ; i < n && len(slice) < limit; i++ {
|
||||||
|
slice = append(slice, c.sorted[i].provisioner)
|
||||||
|
}
|
||||||
|
|
||||||
|
if i < n {
|
||||||
|
return slice, strings.TrimLeft(c.sorted[i].uid, "0")
|
||||||
|
}
|
||||||
|
return slice, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadProvisioner(m *sync.Map, key string) (Interface, bool) {
|
||||||
|
i, ok := m.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
p, ok := i.(Interface)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// provisionerSum returns the SHA1 of the provisioners ID. From this we will
|
||||||
|
// create the unique and sorted id.
|
||||||
|
func provisionerSum(p Interface) []byte {
|
||||||
|
sum := sha1.Sum([]byte(p.GetID()))
|
||||||
|
return sum[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesAudience returns true if A and B share at least one element.
|
||||||
|
func matchesAudience(as, bs []string) bool {
|
||||||
|
if len(bs) == 0 || len(as) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range bs {
|
||||||
|
for _, a := range as {
|
||||||
|
if b == a || stripPort(a) == stripPort(b) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripPort attempts to strip the port from the given url. If parsing the url
|
||||||
|
// produces errors it will just return the passed argument.
|
||||||
|
func stripPort(rawurl string) string {
|
||||||
|
u, err := url.Parse(rawurl)
|
||||||
|
if err != nil {
|
||||||
|
return rawurl
|
||||||
|
}
|
||||||
|
u.Host = u.Hostname()
|
||||||
|
return u.String()
|
||||||
|
}
|
390
authority/provisioner/collection_test.go
Normal file
390
authority/provisioner/collection_test.go
Normal file
|
@ -0,0 +1,390 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCollection_Load(t *testing.T) {
|
||||||
|
p, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
byID := new(sync.Map)
|
||||||
|
byID.Store(p.GetID(), p)
|
||||||
|
byID.Store("string", "a-string")
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
byID *sync.Map
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want Interface
|
||||||
|
want1 bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{byID}, args{p.GetID()}, p, true},
|
||||||
|
{"fail", fields{byID}, args{"fail"}, nil, false},
|
||||||
|
{"invalid", fields{byID}, args{"string"}, nil, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Collection{
|
||||||
|
byID: tt.fields.byID,
|
||||||
|
}
|
||||||
|
got, got1 := c.Load(tt.args.id)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Collection.Load() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollection_LoadByToken(t *testing.T) {
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p3, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
byID := new(sync.Map)
|
||||||
|
byID.Store(p1.GetID(), p1)
|
||||||
|
byID.Store(p2.GetID(), p2)
|
||||||
|
byID.Store(p3.GetID(), p3)
|
||||||
|
byID.Store("string", "a-string")
|
||||||
|
|
||||||
|
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
token, err := generateSimpleToken(p1.Name, testAudiences[0], jwk)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t1, c1, err := parseToken(token)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
jwk, err = decryptJSONWebKey(p2.EncryptedKey)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
token, err = generateSimpleToken(p2.Name, testAudiences[1], jwk)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t2, c2, err := parseToken(token)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t3, c3, err := parseToken(token)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t4, c4, err := parseToken(token)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
byID *sync.Map
|
||||||
|
audiences []string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
token *jose.JSONWebToken
|
||||||
|
claims *jose.Claims
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want Interface
|
||||||
|
want1 bool
|
||||||
|
}{
|
||||||
|
{"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true},
|
||||||
|
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
|
||||||
|
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
|
||||||
|
{"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false},
|
||||||
|
{"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Collection{
|
||||||
|
byID: tt.fields.byID,
|
||||||
|
audiences: tt.fields.audiences,
|
||||||
|
}
|
||||||
|
got, got1 := c.LoadByToken(tt.args.token, tt.args.claims)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollection_LoadByCertificate(t *testing.T) {
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
byID := new(sync.Map)
|
||||||
|
byID.Store(p1.GetID(), p1)
|
||||||
|
byID.Store(p2.GetID(), p2)
|
||||||
|
|
||||||
|
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
ok1Cert := &x509.Certificate{
|
||||||
|
Extensions: []pkix.Extension{ok1Ext},
|
||||||
|
}
|
||||||
|
ok2Cert := &x509.Certificate{
|
||||||
|
Extensions: []pkix.Extension{ok2Ext},
|
||||||
|
}
|
||||||
|
notFoundCert := &x509.Certificate{
|
||||||
|
Extensions: []pkix.Extension{notFoundExt},
|
||||||
|
}
|
||||||
|
badCert := &x509.Certificate{
|
||||||
|
Extensions: []pkix.Extension{
|
||||||
|
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
byID *sync.Map
|
||||||
|
audiences []string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want Interface
|
||||||
|
want1 bool
|
||||||
|
}{
|
||||||
|
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
|
||||||
|
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
|
||||||
|
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
||||||
|
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
|
||||||
|
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Collection{
|
||||||
|
byID: tt.fields.byID,
|
||||||
|
audiences: tt.fields.audiences,
|
||||||
|
}
|
||||||
|
got, got1 := c.LoadByCertificate(tt.args.cert)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollection_LoadEncryptedKey(t *testing.T) {
|
||||||
|
c := NewCollection(testAudiences)
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.FatalError(t, c.Store(p1))
|
||||||
|
p2, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.FatalError(t, c.Store(p2))
|
||||||
|
|
||||||
|
// Add oidc in byKey.
|
||||||
|
// It should not happen.
|
||||||
|
p2KeyID := p2.keyStore.keySet.Keys[0].KeyID
|
||||||
|
c.byKey.Store(p2KeyID, p2)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
keyID string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
want1 bool
|
||||||
|
}{
|
||||||
|
{"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true},
|
||||||
|
{"oidc", args{p2KeyID}, "", false},
|
||||||
|
{"notFound", args{"not-found"}, "", false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, got1 := c.LoadEncryptedKey(tt.args.keyID)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollection_Store(t *testing.T) {
|
||||||
|
c := NewCollection(testAudiences)
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
p Interface
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok1", args{p1}, false},
|
||||||
|
{"ok2", args{p2}, false},
|
||||||
|
{"fail1", args{p1}, true},
|
||||||
|
{"fail2", args{p2}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := c.Store(tt.args.p); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollection_Find(t *testing.T) {
|
||||||
|
c, err := generateCollection(10, 10)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
trim := func(s string) string {
|
||||||
|
return strings.TrimLeft(s, "0")
|
||||||
|
}
|
||||||
|
toList := func(ps provisionerSlice) List {
|
||||||
|
l := List{}
|
||||||
|
for _, p := range ps {
|
||||||
|
l = append(l, p.provisioner)
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cursor string
|
||||||
|
limit int
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want List
|
||||||
|
want1 string
|
||||||
|
}{
|
||||||
|
{"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""},
|
||||||
|
{"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""},
|
||||||
|
{"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)},
|
||||||
|
{"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""},
|
||||||
|
{"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)},
|
||||||
|
{"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)},
|
||||||
|
{"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""},
|
||||||
|
{"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, got1 := c.Find(tt.args.cursor, tt.args.limit)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Collection.Find() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_matchesAudience(t *testing.T) {
|
||||||
|
type matchesTest struct {
|
||||||
|
a, b []string
|
||||||
|
exp bool
|
||||||
|
}
|
||||||
|
tests := map[string]matchesTest{
|
||||||
|
"false arg1 empty": {
|
||||||
|
a: []string{},
|
||||||
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||||
|
exp: false,
|
||||||
|
},
|
||||||
|
"false arg2 empty": {
|
||||||
|
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||||
|
b: []string{},
|
||||||
|
exp: false,
|
||||||
|
},
|
||||||
|
"false arg1,arg2 empty": {
|
||||||
|
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||||
|
b: []string{"step-gateway", "step-cli"},
|
||||||
|
exp: false,
|
||||||
|
},
|
||||||
|
"false": {
|
||||||
|
a: []string{"step-gateway", "step-cli"},
|
||||||
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||||
|
exp: false,
|
||||||
|
},
|
||||||
|
"true": {
|
||||||
|
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
||||||
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||||
|
exp: true,
|
||||||
|
},
|
||||||
|
"true,portsA": {
|
||||||
|
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
||||||
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||||
|
exp: true,
|
||||||
|
},
|
||||||
|
"true,portsB": {
|
||||||
|
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
||||||
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"},
|
||||||
|
exp: true,
|
||||||
|
},
|
||||||
|
"true,portsAB": {
|
||||||
|
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
||||||
|
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"},
|
||||||
|
exp: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_stripPort(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
rawurl string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"},
|
||||||
|
{"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"},
|
||||||
|
{"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := stripPort(tt.args.rawurl); got != tt.want {
|
||||||
|
t.Errorf("stripPort() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
45
authority/provisioner/duration.go
Normal file
45
authority/provisioner/duration.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
|
||||||
|
type Duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON parses a duration string and sets it to the duration.
|
||||||
|
//
|
||||||
|
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||||
|
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||||
|
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||||
|
func (d *Duration) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(d.Duration.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||||
|
//
|
||||||
|
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||||
|
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||||
|
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||||
|
func (d *Duration) UnmarshalJSON(data []byte) (err error) {
|
||||||
|
var (
|
||||||
|
s string
|
||||||
|
_d time.Duration
|
||||||
|
)
|
||||||
|
if d == nil {
|
||||||
|
return errors.New("duration cannot be nil")
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(data, &s); err != nil {
|
||||||
|
return errors.Wrapf(err, "error unmarshaling %s", data)
|
||||||
|
}
|
||||||
|
if _d, err = time.ParseDuration(s); err != nil {
|
||||||
|
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||||
|
}
|
||||||
|
d.Duration = _d
|
||||||
|
return
|
||||||
|
}
|
61
authority/provisioner/duration_test.go
Normal file
61
authority/provisioner/duration_test.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDuration_UnmarshalJSON(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
d *Duration
|
||||||
|
args args
|
||||||
|
want *Duration
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"empty", new(Duration), args{[]byte{}}, new(Duration), true},
|
||||||
|
{"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true},
|
||||||
|
{"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true},
|
||||||
|
{"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true},
|
||||||
|
{"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false},
|
||||||
|
{"nil", nil, args{nil}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(tt.d, tt.want) {
|
||||||
|
t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDuration_MarshalJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
d *Duration
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.d.MarshalJSON()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
125
authority/provisioner/jwk.go
Normal file
125
authority/provisioner/jwk.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
// jwtPayload extends jwt.Claims with step attributes.
|
||||||
|
type jwtPayload struct {
|
||||||
|
jose.Claims
|
||||||
|
SANs []string `json:"sans,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWK is the default provisioner, an entity that can sign tokens necessary for
|
||||||
|
// signature requests.
|
||||||
|
type JWK struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Key *jose.JSONWebKey `json:"key"`
|
||||||
|
EncryptedKey string `json:"encryptedKey,omitempty"`
|
||||||
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
|
audiences []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the provisioner unique identifier. The name and credential id
|
||||||
|
// should uniquely identify any JWK provisioner.
|
||||||
|
func (p *JWK) GetID() string {
|
||||||
|
return p.Name + ":" + p.Key.KeyID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName returns the name of the provisioner.
|
||||||
|
func (p *JWK) GetName() string {
|
||||||
|
return p.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of provisioner.
|
||||||
|
func (p *JWK) GetType() Type {
|
||||||
|
return TypeJWK
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
|
||||||
|
func (p *JWK) GetEncryptedKey() (string, string, bool) {
|
||||||
|
return p.Key.KeyID, p.EncryptedKey, len(p.EncryptedKey) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes and validates the fields of a JWK type.
|
||||||
|
func (p *JWK) Init(config Config) (err error) {
|
||||||
|
switch {
|
||||||
|
case p.Type == "":
|
||||||
|
return errors.New("provisioner type cannot be empty")
|
||||||
|
case p.Name == "":
|
||||||
|
return errors.New("provisioner name cannot be empty")
|
||||||
|
case p.Key == nil:
|
||||||
|
return errors.New("provisioner key cannot be empty")
|
||||||
|
}
|
||||||
|
p.Claims, err = p.Claims.Init(&config.Claims)
|
||||||
|
p.audiences = config.Audiences
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize validates the given token.
|
||||||
|
func (p *JWK) Authorize(token string) ([]SignOption, error) {
|
||||||
|
jwt, err := jose.ParseSigned(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error parsing token")
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims jwtPayload
|
||||||
|
if err = jwt.Claims(p.Key, &claims); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error parsing claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||||
|
// more than a few minutes.
|
||||||
|
if err = claims.ValidateWithLeeway(jose.Expected{
|
||||||
|
Issuer: p.Name,
|
||||||
|
Time: time.Now().UTC(),
|
||||||
|
}, time.Minute); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate audiences with the defaults
|
||||||
|
if !matchesAudience(claims.Audience, p.audiences) {
|
||||||
|
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Subject == "" {
|
||||||
|
return nil, errors.New("token subject cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: This is for backwards compatibility with older versions of cli
|
||||||
|
// and certificates. Older versions added the token subject as the only SAN
|
||||||
|
// in a CSR by default.
|
||||||
|
if len(claims.SANs) == 0 {
|
||||||
|
claims.SANs = []string{claims.Subject}
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsNames, ips := x509util.SplitSANs(claims.SANs)
|
||||||
|
return []SignOption{
|
||||||
|
commonNameValidator(claims.Subject),
|
||||||
|
dnsNamesValidator(dnsNames),
|
||||||
|
ipAddressesValidator(ips),
|
||||||
|
profileDefaultDuration(p.Claims.DefaultTLSCertDuration()),
|
||||||
|
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
|
||||||
|
newValidityValidator(p.Claims.MinTLSCertDuration(), p.Claims.MaxTLSCertDuration()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRenewal returns an error if the renewal is disabled.
|
||||||
|
func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||||
|
if p.Claims.IsDisableRenewal() {
|
||||||
|
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||||
|
// revoke the certificate with serial number in the `sub` property.
|
||||||
|
func (p *JWK) AuthorizeRevoke(token string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
256
authority/provisioner/jwk_test.go
Normal file
256
authority/provisioner/jwk_test.go
Normal file
|
@ -0,0 +1,256 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultDisableRenewal = false
|
||||||
|
globalProvisionerClaims = Claims{
|
||||||
|
MinTLSDur: &Duration{5 * time.Minute},
|
||||||
|
MaxTLSDur: &Duration{24 * time.Hour},
|
||||||
|
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||||
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJWK_Getters(t *testing.T) {
|
||||||
|
p, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
if got := p.GetID(); got != p.Name+":"+p.Key.KeyID {
|
||||||
|
t.Errorf("JWK.GetID() = %v, want %v:%v", got, p.Name, p.Key.KeyID)
|
||||||
|
}
|
||||||
|
if got := p.GetName(); got != p.Name {
|
||||||
|
t.Errorf("JWK.GetName() = %v, want %v", got, p.Name)
|
||||||
|
}
|
||||||
|
if got := p.GetType(); got != TypeJWK {
|
||||||
|
t.Errorf("JWK.GetType() = %v, want %v", got, TypeJWK)
|
||||||
|
}
|
||||||
|
kid, key, ok := p.GetEncryptedKey()
|
||||||
|
if kid != p.Key.KeyID || key != p.EncryptedKey || ok == false {
|
||||||
|
t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||||
|
kid, key, ok, p.Key.KeyID, p.EncryptedKey, true)
|
||||||
|
}
|
||||||
|
p.EncryptedKey = ""
|
||||||
|
kid, key, ok = p.GetEncryptedKey()
|
||||||
|
if kid != p.Key.KeyID || key != "" || ok == true {
|
||||||
|
t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||||
|
kid, key, ok, p.Key.KeyID, "", false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWK_Init(t *testing.T) {
|
||||||
|
type ProvisionerValidateTest struct {
|
||||||
|
p *JWK
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
||||||
|
"fail-empty": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &JWK{},
|
||||||
|
err: errors.New("provisioner type cannot be empty"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &JWK{
|
||||||
|
Type: "JWK",
|
||||||
|
},
|
||||||
|
err: errors.New("provisioner name cannot be empty"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &JWK{Name: "foo"},
|
||||||
|
err: errors.New("provisioner type cannot be empty"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail-empty-key": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &JWK{Name: "foo", Type: "bar"},
|
||||||
|
err: errors.New("provisioner key cannot be empty"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config := Config{
|
||||||
|
Claims: globalProvisionerClaims,
|
||||||
|
Audiences: testAudiences,
|
||||||
|
}
|
||||||
|
for name, get := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := get(t)
|
||||||
|
err := tc.p.Init(config)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.err.Error(), err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWK_Authorize(t *testing.T) {
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
key1, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
key2, err := decryptJSONWebKey(p2.EncryptedKey)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], "", []string{}, time.Now(), key1)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// Invalid tokens
|
||||||
|
parts := strings.Split(t1, ".")
|
||||||
|
key3, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// missing key
|
||||||
|
failKey, err := generateSimpleToken(p1.Name, testAudiences[0], key3)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// invalid token
|
||||||
|
failTok := "foo." + parts[1] + "." + parts[2]
|
||||||
|
// invalid claims
|
||||||
|
failClaims := parts[0] + ".foo." + parts[1]
|
||||||
|
// invalid issuer
|
||||||
|
failIss, err := generateSimpleToken("foobar", testAudiences[0], key1)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// invalid audience
|
||||||
|
failAud, err := generateSimpleToken(p1.Name, "foobar", key1)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// invalid signature
|
||||||
|
failSig := t1[0 : len(t1)-2]
|
||||||
|
// no subject
|
||||||
|
failSub, err := generateToken("", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now(), key1)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// expired
|
||||||
|
failExp, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// not before
|
||||||
|
failNbf, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// Remove encrypted key for p2
|
||||||
|
p2.EncryptedKey = ""
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *JWK
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{t1}, false},
|
||||||
|
{"ok-no-encrypted-key", p2, args{t2}, false},
|
||||||
|
{"ok-no-sans", p1, args{t3}, false},
|
||||||
|
{"fail-key", p1, args{failKey}, true},
|
||||||
|
{"fail-token", p1, args{failTok}, true},
|
||||||
|
{"fail-claims", p1, args{failClaims}, true},
|
||||||
|
{"fail-issuer", p1, args{failIss}, true},
|
||||||
|
{"fail-audience", p1, args{failAud}, true},
|
||||||
|
{"fail-signature", p1, args{failSig}, true},
|
||||||
|
{"fail-subject", p1, args{failSub}, true},
|
||||||
|
{"fail-expired", p1, args{failExp}, true},
|
||||||
|
{"fail-not-before", p1, args{failNbf}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.prov.Authorize(tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("JWK.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
assert.Nil(t, got)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, got)
|
||||||
|
assert.Len(t, 6, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWK_AuthorizeRenewal(t *testing.T) {
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// disable renewal
|
||||||
|
disable := true
|
||||||
|
p2.Claims = &Claims{
|
||||||
|
globalClaims: &globalProvisionerClaims,
|
||||||
|
DisableRenewal: &disable,
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *JWK
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{nil}, false},
|
||||||
|
{"fail", p2, args{nil}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("JWK.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWK_AuthorizeRevoke(t *testing.T) {
|
||||||
|
p1, err := generateJWK()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
key1, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *JWK
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"disabled", p1, args{t1}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("JWK.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
135
authority/provisioner/keystore.go
Normal file
135
authority/provisioner/keystore.go
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultCacheAge = 12 * time.Hour
|
||||||
|
defaultCacheJitter = 1 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
||||||
|
|
||||||
|
type keyStore struct {
|
||||||
|
sync.RWMutex
|
||||||
|
uri string
|
||||||
|
keySet jose.JSONWebKeySet
|
||||||
|
timer *time.Timer
|
||||||
|
expiry time.Time
|
||||||
|
jitter time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newKeyStore(uri string) (*keyStore, error) {
|
||||||
|
keys, age, err := getKeysFromJWKsURI(uri)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ks := &keyStore{
|
||||||
|
uri: uri,
|
||||||
|
keySet: keys,
|
||||||
|
expiry: getExpirationTime(age),
|
||||||
|
jitter: getCacheJitter(age),
|
||||||
|
}
|
||||||
|
next := ks.nextReloadDuration(age)
|
||||||
|
ks.timer = time.AfterFunc(next, ks.reload)
|
||||||
|
return ks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *keyStore) Close() {
|
||||||
|
ks.timer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
||||||
|
ks.RLock()
|
||||||
|
// Force reload if expiration has passed
|
||||||
|
if time.Now().After(ks.expiry) {
|
||||||
|
ks.RUnlock()
|
||||||
|
ks.reload()
|
||||||
|
ks.RLock()
|
||||||
|
}
|
||||||
|
keys = ks.keySet.Key(kid)
|
||||||
|
ks.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *keyStore) reload() {
|
||||||
|
var next time.Duration
|
||||||
|
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
||||||
|
if err != nil {
|
||||||
|
next = ks.nextReloadDuration(ks.jitter / 2)
|
||||||
|
} else {
|
||||||
|
ks.Lock()
|
||||||
|
ks.keySet = keys
|
||||||
|
ks.expiry = getExpirationTime(age)
|
||||||
|
ks.jitter = getCacheJitter(age)
|
||||||
|
next = ks.nextReloadDuration(age)
|
||||||
|
ks.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.Lock()
|
||||||
|
ks.timer.Reset(next)
|
||||||
|
ks.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||||
|
n := rand.Int63n(int64(ks.jitter))
|
||||||
|
age -= time.Duration(n)
|
||||||
|
if age < 0 {
|
||||||
|
age = 0
|
||||||
|
}
|
||||||
|
return age
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||||
|
var keys jose.JSONWebKeySet
|
||||||
|
resp, err := http.Get(uri)
|
||||||
|
if err != nil {
|
||||||
|
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
|
||||||
|
return keys, 0, errors.Wrapf(err, "error reading %s", uri)
|
||||||
|
}
|
||||||
|
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCacheAge(cacheControl string) time.Duration {
|
||||||
|
age := defaultCacheAge
|
||||||
|
if len(cacheControl) > 0 {
|
||||||
|
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
|
||||||
|
if len(match) > 0 {
|
||||||
|
if len(match[0]) == 2 {
|
||||||
|
maxAge := match[0][1]
|
||||||
|
maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return defaultCacheAge
|
||||||
|
}
|
||||||
|
age = time.Duration(maxAgeInt) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return age
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCacheJitter(age time.Duration) time.Duration {
|
||||||
|
switch {
|
||||||
|
case age > time.Hour:
|
||||||
|
return defaultCacheJitter
|
||||||
|
default:
|
||||||
|
return age / 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExpirationTime(age time.Duration) time.Time {
|
||||||
|
return time.Now().Truncate(time.Second).Add(age)
|
||||||
|
}
|
121
authority/provisioner/keystore_test.go
Normal file
121
authority/provisioner/keystore_test.go
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_newKeyStore(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
ks, err := newKeyStore(srv.URL)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer ks.Close()
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
uri string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want jose.JSONWebKeySet
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{srv.URL}, ks.keySet, false},
|
||||||
|
{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := newKeyStore(tt.args.uri)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
if !reflect.DeepEqual(got.keySet, tt.want) {
|
||||||
|
t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
got.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_keyStore(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
ks, err := newKeyStore(srv.URL + "/random")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer ks.Close()
|
||||||
|
ks.RLock()
|
||||||
|
keySet1 := ks.keySet
|
||||||
|
ks.RUnlock()
|
||||||
|
// Check contents
|
||||||
|
assert.Len(t, 2, keySet1.Keys)
|
||||||
|
assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
|
||||||
|
assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
|
||||||
|
assert.Len(t, 0, ks.Get("foobar"))
|
||||||
|
|
||||||
|
// Wait for rotation
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
ks.RLock()
|
||||||
|
keySet2 := ks.keySet
|
||||||
|
ks.RUnlock()
|
||||||
|
if reflect.DeepEqual(keySet1, keySet2) {
|
||||||
|
t.Error("keyStore did not rotated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check contents
|
||||||
|
assert.Len(t, 2, keySet2.Keys)
|
||||||
|
assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
|
||||||
|
assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
|
||||||
|
assert.Len(t, 0, ks.Get("foobar"))
|
||||||
|
|
||||||
|
// Check hits
|
||||||
|
resp, err := srv.Client().Get(srv.URL + "/hits")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
hits := struct {
|
||||||
|
Hits int `json:"hits"`
|
||||||
|
}{}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&hits)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_keyStore_Get(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
ks, err := newKeyStore(srv.URL)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer ks.Close()
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
kid string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ks *keyStore
|
||||||
|
args args
|
||||||
|
wantKeys []jose.JSONWebKey
|
||||||
|
}{
|
||||||
|
{"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}},
|
||||||
|
{"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}},
|
||||||
|
{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
|
||||||
|
t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
37
authority/provisioner/noop.go
Normal file
37
authority/provisioner/noop.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import "crypto/x509"
|
||||||
|
|
||||||
|
// noop provisioners is a provisioner that accepts anything.
|
||||||
|
type noop struct{}
|
||||||
|
|
||||||
|
func (p *noop) GetID() string {
|
||||||
|
return "noop"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *noop) GetName() string {
|
||||||
|
return "noop"
|
||||||
|
}
|
||||||
|
func (p *noop) GetType() Type {
|
||||||
|
return noopType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *noop) Init(config Config) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *noop) Authorize(token string) ([]SignOption, error) {
|
||||||
|
return []SignOption{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *noop) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *noop) AuthorizeRevoke(token string) error {
|
||||||
|
return nil
|
||||||
|
}
|
27
authority/provisioner/noop_test.go
Normal file
27
authority/provisioner/noop_test.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_noop(t *testing.T) {
|
||||||
|
p := noop{}
|
||||||
|
assert.Equals(t, "noop", p.GetID())
|
||||||
|
assert.Equals(t, "noop", p.GetName())
|
||||||
|
assert.Equals(t, noopType, p.GetType())
|
||||||
|
assert.Equals(t, nil, p.Init(Config{}))
|
||||||
|
assert.Equals(t, nil, p.AuthorizeRenewal(&x509.Certificate{}))
|
||||||
|
assert.Equals(t, nil, p.AuthorizeRevoke("foo"))
|
||||||
|
|
||||||
|
kid, key, ok := p.GetEncryptedKey()
|
||||||
|
assert.Equals(t, "", kid)
|
||||||
|
assert.Equals(t, "", key)
|
||||||
|
assert.Equals(t, false, ok)
|
||||||
|
|
||||||
|
sigOptions, err := p.Authorize("foo")
|
||||||
|
assert.Equals(t, []SignOption{}, sigOptions)
|
||||||
|
assert.Equals(t, nil, err)
|
||||||
|
}
|
243
authority/provisioner/oidc.go
Normal file
243
authority/provisioner/oidc.go
Normal file
|
@ -0,0 +1,243 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
// openIDConfiguration contains the necessary properties in the
|
||||||
|
// `/.well-known/openid-configuration` document.
|
||||||
|
type openIDConfiguration struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
JWKSetURI string `json:"jwks_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the values in a well-known OpenID configuration endpoint.
|
||||||
|
func (c openIDConfiguration) Validate() error {
|
||||||
|
switch {
|
||||||
|
case c.Issuer == "":
|
||||||
|
return errors.New("issuer cannot be empty")
|
||||||
|
case c.JWKSetURI == "":
|
||||||
|
return errors.New("jwks_uri cannot be empty")
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// openIDPayload represents the fields on the id_token JWT payload.
|
||||||
|
type openIDPayload struct {
|
||||||
|
jose.Claims
|
||||||
|
AtHash string `json:"at_hash"`
|
||||||
|
AuthorizedParty string `json:"azp"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
Hd string `json:"hd"`
|
||||||
|
Nonce string `json:"nonce"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDC represents an OAuth 2.0 OpenID Connect provider.
|
||||||
|
//
|
||||||
|
// ClientSecret is mandatory, but it can be an empty string.
|
||||||
|
type OIDC struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
ClientID string `json:"clientID"`
|
||||||
|
ClientSecret string `json:"clientSecret"`
|
||||||
|
ConfigurationEndpoint string `json:"configurationEndpoint"`
|
||||||
|
Admins []string `json:"admins,omitempty"`
|
||||||
|
Domains []string `json:"domains,omitempty"`
|
||||||
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
|
configuration openIDConfiguration
|
||||||
|
keyStore *keyStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAdmin returns true if the given email is in the Admins whitelist, false
|
||||||
|
// otherwise.
|
||||||
|
func (o *OIDC) IsAdmin(email string) bool {
|
||||||
|
email = sanitizeEmail(email)
|
||||||
|
for _, e := range o.Admins {
|
||||||
|
if email == sanitizeEmail(e) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeEmail(email string) string {
|
||||||
|
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||||
|
email = email[:i] + strings.ToLower(email[i:])
|
||||||
|
}
|
||||||
|
return email
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the provisioner unique identifier, the OIDC provisioner the
|
||||||
|
// uses the clientID for this.
|
||||||
|
func (o *OIDC) GetID() string {
|
||||||
|
return o.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName returns the name of the provisioner.
|
||||||
|
func (o *OIDC) GetName() string {
|
||||||
|
return o.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of provisioner.
|
||||||
|
func (o *OIDC) GetType() Type {
|
||||||
|
return TypeOIDC
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEncryptedKey is not available in an OIDC provisioner.
|
||||||
|
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init validates and initializes the OIDC provider.
|
||||||
|
func (o *OIDC) Init(config Config) (err error) {
|
||||||
|
switch {
|
||||||
|
case o.Type == "":
|
||||||
|
return errors.New("type cannot be empty")
|
||||||
|
case o.Name == "":
|
||||||
|
return errors.New("name cannot be empty")
|
||||||
|
case o.ClientID == "":
|
||||||
|
return errors.New("clientID cannot be empty")
|
||||||
|
case o.ConfigurationEndpoint == "":
|
||||||
|
return errors.New("configurationEndpoint cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update claims with global ones
|
||||||
|
if o.Claims, err = o.Claims.Init(&config.Claims); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Decode and validate openid-configuration endpoint
|
||||||
|
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := o.configuration.Validate(); err != nil {
|
||||||
|
return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
|
||||||
|
}
|
||||||
|
// Get JWK key set
|
||||||
|
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidatePayload validates the given token payload.
|
||||||
|
func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
||||||
|
// According to "rfc7519 JSON Web Token" acceptable skew should be no more
|
||||||
|
// than a few minutes.
|
||||||
|
if err := p.ValidateWithLeeway(jose.Expected{
|
||||||
|
Issuer: o.configuration.Issuer,
|
||||||
|
Audience: jose.Audience{o.ClientID},
|
||||||
|
Time: time.Now().UTC(),
|
||||||
|
}, time.Minute); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to validate payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate azp if present
|
||||||
|
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
|
||||||
|
return errors.New("failed to validate payload: invalid azp")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enforce an email claim
|
||||||
|
if p.Email == "" {
|
||||||
|
return errors.New("failed to validate payload: email not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate domains (case-insensitive)
|
||||||
|
if !o.IsAdmin(p.Email) && len(o.Domains) > 0 {
|
||||||
|
email := sanitizeEmail(p.Email)
|
||||||
|
var found bool
|
||||||
|
for _, d := range o.Domains {
|
||||||
|
if strings.HasSuffix(email, "@"+strings.ToLower(d)) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return errors.New("failed to validate payload: email is not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize validates the given token.
|
||||||
|
func (o *OIDC) Authorize(token string) ([]SignOption, error) {
|
||||||
|
jwt, err := jose.ParseSigned(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error parsing token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse claims to get the kid
|
||||||
|
var claims openIDPayload
|
||||||
|
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error parsing claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
kid := jwt.Headers[0].KeyID
|
||||||
|
keys := o.keyStore.Get(kid)
|
||||||
|
for _, key := range keys {
|
||||||
|
if err := jwt.Claims(key, &claims); err == nil {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, errors.New("cannot validate token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := o.ValidatePayload(claims); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admins should be able to authorize any SAN
|
||||||
|
if o.IsAdmin(claims.Email) {
|
||||||
|
return []SignOption{
|
||||||
|
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
|
||||||
|
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||||
|
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return []SignOption{
|
||||||
|
emailOnlyIdentity(claims.Email),
|
||||||
|
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
|
||||||
|
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||||
|
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRenewal returns an error if the renewal is disabled.
|
||||||
|
func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||||
|
if o.Claims.IsDisableRenewal() {
|
||||||
|
return errors.Errorf("renew is disabled for provisioner %s", o.GetID())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||||
|
// revoke the certificate with serial number in the `sub` property.
|
||||||
|
func (o *OIDC) AuthorizeRevoke(token string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAndDecode(uri string, v interface{}) error {
|
||||||
|
resp, err := http.Get(uri)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrapf(err, "failed to connect to %s", uri)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
|
||||||
|
return errors.Wrapf(err, "error reading %s", uri)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
327
authority/provisioner/oidc_test.go
Normal file
327
authority/provisioner/oidc_test.go
Normal file
|
@ -0,0 +1,327 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_openIDConfiguration_Validate(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Issuer string
|
||||||
|
JWKSetURI string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{"the-issuer", "the-jwks-uri"}, false},
|
||||||
|
{"no-issuer", fields{"", "the-jwks-uri"}, true},
|
||||||
|
{"no-jwks-uri", fields{"the-issuer", ""}, true},
|
||||||
|
{"empty", fields{"", ""}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := openIDConfiguration{
|
||||||
|
Issuer: tt.fields.Issuer,
|
||||||
|
JWKSetURI: tt.fields.JWKSetURI,
|
||||||
|
}
|
||||||
|
if err := c.Validate(); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("openIDConfiguration.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOIDC_Getters(t *testing.T) {
|
||||||
|
p, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
if got := p.GetID(); got != p.ClientID {
|
||||||
|
t.Errorf("OIDC.GetID() = %v, want %v", got, p.ClientID)
|
||||||
|
}
|
||||||
|
if got := p.GetName(); got != p.Name {
|
||||||
|
t.Errorf("OIDC.GetName() = %v, want %v", got, p.Name)
|
||||||
|
}
|
||||||
|
if got := p.GetType(); got != TypeOIDC {
|
||||||
|
t.Errorf("OIDC.GetType() = %v, want %v", got, TypeOIDC)
|
||||||
|
}
|
||||||
|
kid, key, ok := p.GetEncryptedKey()
|
||||||
|
if kid != "" || key != "" || ok == true {
|
||||||
|
t.Errorf("OIDC.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||||
|
kid, key, ok, "", "", false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOIDC_Init(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
config := Config{
|
||||||
|
Claims: globalProvisionerClaims,
|
||||||
|
}
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
ConfigurationEndpoint string
|
||||||
|
Claims *Claims
|
||||||
|
Admins []string
|
||||||
|
Domains []string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
config Config
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
|
||||||
|
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}, nil}, args{config}, false},
|
||||||
|
{"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, []string{"smallstep.com"}}, args{config}, false},
|
||||||
|
{"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
|
||||||
|
{"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||||
|
{"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||||
|
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||||
|
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true},
|
||||||
|
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &OIDC{
|
||||||
|
Type: tt.fields.Type,
|
||||||
|
Name: tt.fields.Name,
|
||||||
|
ClientID: tt.fields.ClientID,
|
||||||
|
ConfigurationEndpoint: tt.fields.ConfigurationEndpoint,
|
||||||
|
Claims: tt.fields.Claims,
|
||||||
|
Admins: tt.fields.Admins,
|
||||||
|
}
|
||||||
|
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if tt.wantErr == false {
|
||||||
|
assert.Len(t, 2, p.keyStore.keySet.Keys)
|
||||||
|
assert.Equals(t, openIDConfiguration{
|
||||||
|
Issuer: "the-issuer",
|
||||||
|
JWKSetURI: srv.URL + "/jwks_uri",
|
||||||
|
}, p.configuration)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOIDC_Authorize(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
var keys jose.JSONWebKeySet
|
||||||
|
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||||
|
|
||||||
|
// Create test provisioners
|
||||||
|
p1, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p3, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// Admin + Domains
|
||||||
|
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
|
||||||
|
p3.Domains = []string{"smallstep.com"}
|
||||||
|
|
||||||
|
// Update configuration endpoints and initialize
|
||||||
|
config := Config{Claims: globalProvisionerClaims}
|
||||||
|
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||||
|
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||||
|
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||||
|
assert.FatalError(t, p1.Init(config))
|
||||||
|
assert.FatalError(t, p2.Init(config))
|
||||||
|
assert.FatalError(t, p3.Init(config))
|
||||||
|
|
||||||
|
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t2, err := generateSimpleToken("the-issuer", p2.ClientID, &keys.Keys[1])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// Admin email not in domains
|
||||||
|
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// Invalid email
|
||||||
|
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
failDomain, err := generateToken("subject", "the-issuer", p3.ClientID, "name@example.com", []string{}, time.Now(), &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// Invalid tokens
|
||||||
|
parts := strings.Split(t1, ".")
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// missing key
|
||||||
|
failKey, err := generateSimpleToken("the-issuer", p1.ClientID, key)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// invalid token
|
||||||
|
failTok := "foo." + parts[1] + "." + parts[2]
|
||||||
|
// invalid claims
|
||||||
|
failClaims := parts[0] + ".foo." + parts[1]
|
||||||
|
// invalid issuer
|
||||||
|
failIss, err := generateSimpleToken("bad-issuer", p1.ClientID, &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// invalid audience
|
||||||
|
failAud, err := generateSimpleToken("the-issuer", "foobar", &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// invalid signature
|
||||||
|
failSig := t1[0 : len(t1)-2]
|
||||||
|
// expired
|
||||||
|
failExp, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// not before
|
||||||
|
failNbf, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(360*time.Second), &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *OIDC
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok1", p1, args{t1}, false},
|
||||||
|
{"ok2", p2, args{t2}, false},
|
||||||
|
{"admin", p3, args{t3}, false},
|
||||||
|
{"admin", p3, args{okAdmin}, false},
|
||||||
|
{"fail-email", p3, args{failEmail}, true},
|
||||||
|
{"fail-domain", p3, args{failDomain}, true},
|
||||||
|
{"fail-key", p1, args{failKey}, true},
|
||||||
|
{"fail-token", p1, args{failTok}, true},
|
||||||
|
{"fail-claims", p1, args{failClaims}, true},
|
||||||
|
{"fail-issuer", p1, args{failIss}, true},
|
||||||
|
{"fail-audience", p1, args{failAud}, true},
|
||||||
|
{"fail-signature", p1, args{failSig}, true},
|
||||||
|
{"fail-expired", p1, args{failExp}, true},
|
||||||
|
{"fail-not-before", p1, args{failNbf}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.prov.Authorize(tt.args.token)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
fmt.Println(tt)
|
||||||
|
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
assert.Nil(t, got)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, got)
|
||||||
|
if tt.name == "admin" {
|
||||||
|
assert.Len(t, 3, got)
|
||||||
|
} else {
|
||||||
|
assert.Len(t, 4, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOIDC_AuthorizeRenewal(t *testing.T) {
|
||||||
|
p1, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// disable renewal
|
||||||
|
disable := true
|
||||||
|
p2.Claims = &Claims{
|
||||||
|
globalClaims: &globalProvisionerClaims,
|
||||||
|
DisableRenewal: &disable,
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *OIDC
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{nil}, false},
|
||||||
|
{"fail", p2, args{nil}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("OIDC.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
var keys jose.JSONWebKeySet
|
||||||
|
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||||
|
|
||||||
|
// Create test provisioners
|
||||||
|
p1, err := generateOIDC()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// Update configuration endpoints and initialize
|
||||||
|
config := Config{Claims: globalProvisionerClaims}
|
||||||
|
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||||
|
assert.FatalError(t, p1.Init(config))
|
||||||
|
|
||||||
|
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *OIDC
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"disabled", p1, args{t1}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("OIDC.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_sanitizeEmail(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
email string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"equal", "name@smallstep.com", "name@smallstep.com"},
|
||||||
|
{"domain-insensitive", "name@SMALLSTEP.COM", "name@smallstep.com"},
|
||||||
|
{"local-sensitive", "NaMe@smallSTEP.CoM", "NaMe@smallstep.com"},
|
||||||
|
{"multiple-@", "NaMe@NaMe@smallSTEP.CoM", "NaMe@NaMe@smallstep.com"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := sanitizeEmail(tt.email); got != tt.want {
|
||||||
|
t.Errorf("sanitizeEmail() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
82
authority/provisioner/provisioner.go
Normal file
82
authority/provisioner/provisioner.go
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Interface is the interface that all provisioner types must implement.
|
||||||
|
type Interface interface {
|
||||||
|
GetID() string
|
||||||
|
GetName() string
|
||||||
|
GetType() Type
|
||||||
|
GetEncryptedKey() (kid string, key string, ok bool)
|
||||||
|
Init(config Config) error
|
||||||
|
Authorize(token string) ([]SignOption, error)
|
||||||
|
AuthorizeRenewal(cert *x509.Certificate) error
|
||||||
|
AuthorizeRevoke(token string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type indicates the provisioner Type.
|
||||||
|
type Type int
|
||||||
|
|
||||||
|
const (
|
||||||
|
noopType Type = 0
|
||||||
|
|
||||||
|
// TypeJWK is used to indicate the JWK provisioners.
|
||||||
|
TypeJWK Type = 1
|
||||||
|
|
||||||
|
// TypeOIDC is used to indicate the OIDC provisioners.
|
||||||
|
TypeOIDC Type = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config defines the default parameters used in the initialization of
|
||||||
|
// provisioners.
|
||||||
|
type Config struct {
|
||||||
|
// Claims are the default claims.
|
||||||
|
Claims Claims
|
||||||
|
// Audiences are the audiences used in the default provisioner, (JWK).
|
||||||
|
Audiences []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type provisioner struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// List represents a list of provisioners.
|
||||||
|
type List []Interface
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler and allows to unmarshal a list of a
|
||||||
|
// interfaces into the right type.
|
||||||
|
func (l *List) UnmarshalJSON(data []byte) error {
|
||||||
|
ps := []json.RawMessage{}
|
||||||
|
if err := json.Unmarshal(data, &ps); err != nil {
|
||||||
|
return errors.Wrap(err, "error unmarshaling provisioner list")
|
||||||
|
}
|
||||||
|
|
||||||
|
*l = List{}
|
||||||
|
for _, data := range ps {
|
||||||
|
var typ provisioner
|
||||||
|
if err := json.Unmarshal(data, &typ); err != nil {
|
||||||
|
return errors.Errorf("error unmarshaling provisioner")
|
||||||
|
}
|
||||||
|
var p Interface
|
||||||
|
switch strings.ToLower(typ.Type) {
|
||||||
|
case "jwk":
|
||||||
|
p = &JWK{}
|
||||||
|
case "oidc":
|
||||||
|
p = &OIDC{}
|
||||||
|
default:
|
||||||
|
return errors.Errorf("provisioner type %s not supported", typ.Type)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, p); err != nil {
|
||||||
|
return errors.Errorf("error unmarshaling provisioner")
|
||||||
|
}
|
||||||
|
*l = append(*l, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
233
authority/provisioner/sign_options.go
Normal file
233
authority/provisioner/sign_options.go
Normal file
|
@ -0,0 +1,233 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Options contains the options that can be passed to the Sign method.
|
||||||
|
type Options struct {
|
||||||
|
NotAfter time.Time `json:"notAfter"`
|
||||||
|
NotBefore time.Time `json:"notBefore"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignOption is the interface used to collect all extra options used in the
|
||||||
|
// Sign method.
|
||||||
|
type SignOption interface{}
|
||||||
|
|
||||||
|
// CertificateValidator is the interface used to validate a X.509 certificate.
|
||||||
|
type CertificateValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(crt *x509.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// CertificateRequestValidator is the interface used to validate a X.509
|
||||||
|
// certificate request.
|
||||||
|
type CertificateRequestValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(req *x509.CertificateRequest) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProfileModifier is the interface used to add custom options to the profile
|
||||||
|
// constructor. The options are used to modify the final certificate.
|
||||||
|
type ProfileModifier interface {
|
||||||
|
SignOption
|
||||||
|
Option(o Options) x509util.WithOption
|
||||||
|
}
|
||||||
|
|
||||||
|
// profileWithOption is a wrapper against x509util.WithOption to conform the
|
||||||
|
// interface.
|
||||||
|
type profileWithOption x509util.WithOption
|
||||||
|
|
||||||
|
func (v profileWithOption) Option(Options) x509util.WithOption {
|
||||||
|
return x509util.WithOption(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// profileDefaultDuration is a wrapper against x509util.WithOption to conform the
|
||||||
|
// interface.
|
||||||
|
type profileDefaultDuration time.Duration
|
||||||
|
|
||||||
|
func (v profileDefaultDuration) Option(so Options) x509util.WithOption {
|
||||||
|
return x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, time.Duration(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
||||||
|
// SAN provided is the given email address.
|
||||||
|
type emailOnlyIdentity string
|
||||||
|
|
||||||
|
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
||||||
|
switch {
|
||||||
|
case len(req.DNSNames) > 0:
|
||||||
|
return errors.New("certificate request cannot contain DNS names")
|
||||||
|
case len(req.IPAddresses) > 0:
|
||||||
|
return errors.New("certificate request cannot contain IP addresses")
|
||||||
|
case len(req.URIs) > 0:
|
||||||
|
return errors.New("certificate request cannot contain URIs")
|
||||||
|
case len(req.EmailAddresses) == 0:
|
||||||
|
return errors.New("certificate request does not contain any email address")
|
||||||
|
case len(req.EmailAddresses) > 1:
|
||||||
|
return errors.New("certificate request does not contain too many email addresses")
|
||||||
|
case req.EmailAddresses[0] == "":
|
||||||
|
return errors.New("certificate request cannot contain an empty email address")
|
||||||
|
case req.EmailAddresses[0] != string(e):
|
||||||
|
return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// commonNameValidator validates the common name of a certificate request.
|
||||||
|
type commonNameValidator string
|
||||||
|
|
||||||
|
// Valid checks that certificate request common name matches the one configured.
|
||||||
|
func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
|
if req.Subject.CommonName == "" {
|
||||||
|
return errors.New("certificate request cannot contain an empty common name")
|
||||||
|
}
|
||||||
|
if req.Subject.CommonName != string(v) {
|
||||||
|
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsNamesValidator validates the DNS names SAN of a certificate request.
|
||||||
|
type dnsNamesValidator []string
|
||||||
|
|
||||||
|
// Valid checks that certificate request DNS Names match those configured in
|
||||||
|
// the bootstrap (token) flow.
|
||||||
|
func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
|
want := make(map[string]bool)
|
||||||
|
for _, s := range v {
|
||||||
|
want[s] = true
|
||||||
|
}
|
||||||
|
got := make(map[string]bool)
|
||||||
|
for _, s := range req.DNSNames {
|
||||||
|
got[s] = true
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipAddressesValidator validates the IP addresses SAN of a certificate request.
|
||||||
|
type ipAddressesValidator []net.IP
|
||||||
|
|
||||||
|
// Valid checks that certificate request IP Addresses match those configured in
|
||||||
|
// the bootstrap (token) flow.
|
||||||
|
func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error {
|
||||||
|
want := make(map[string]bool)
|
||||||
|
for _, ip := range v {
|
||||||
|
want[ip.String()] = true
|
||||||
|
}
|
||||||
|
got := make(map[string]bool)
|
||||||
|
for _, ip := range req.IPAddresses {
|
||||||
|
got[ip.String()] = true
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validityValidator validates the certificate temporal validity settings.
|
||||||
|
type validityValidator struct {
|
||||||
|
min time.Duration
|
||||||
|
max time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// newValidityValidator return a new validity validator.
|
||||||
|
func newValidityValidator(min, max time.Duration) *validityValidator {
|
||||||
|
return &validityValidator{min: min, max: max}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the certificate temporal validity settings.
|
||||||
|
func (v *validityValidator) Valid(crt *x509.Certificate) error {
|
||||||
|
var (
|
||||||
|
na = crt.NotAfter
|
||||||
|
nb = crt.NotBefore
|
||||||
|
d = na.Sub(nb)
|
||||||
|
now = time.Now()
|
||||||
|
)
|
||||||
|
|
||||||
|
if na.Before(now) {
|
||||||
|
return errors.Errorf("NotAfter: %v cannot be in the past", na)
|
||||||
|
}
|
||||||
|
if na.Before(nb) {
|
||||||
|
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
|
||||||
|
}
|
||||||
|
if d < v.min {
|
||||||
|
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
|
||||||
|
d, v.min)
|
||||||
|
}
|
||||||
|
if d > v.max {
|
||||||
|
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
|
||||||
|
d, v.max)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||||
|
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
||||||
|
)
|
||||||
|
|
||||||
|
type stepProvisionerASN1 struct {
|
||||||
|
Type int
|
||||||
|
Name []byte
|
||||||
|
CredentialID []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type provisionerExtensionOption struct {
|
||||||
|
Type int
|
||||||
|
Name string
|
||||||
|
CredentialID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisionerExtensionOption {
|
||||||
|
return &provisionerExtensionOption{
|
||||||
|
Type: int(typ),
|
||||||
|
Name: name,
|
||||||
|
CredentialID: credentialID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
||||||
|
return func(p x509util.Profile) error {
|
||||||
|
crt := p.Subject()
|
||||||
|
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
crt.ExtraExtensions = append(crt.ExtraExtensions, ext)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) {
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||||
|
Type: typ,
|
||||||
|
Name: []byte(name),
|
||||||
|
CredentialID: []byte(credentialID),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
|
||||||
|
}
|
||||||
|
return pkix.Extension{
|
||||||
|
Id: stepOIDProvisioner,
|
||||||
|
Critical: false,
|
||||||
|
Value: b,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Avoid deadcode warning in profileWithOption
|
||||||
|
_ = profileWithOption(nil)
|
||||||
|
}
|
152
authority/provisioner/sign_options_test.go
Normal file
152
authority/provisioner/sign_options_test.go
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_emailOnlyIdentity_Valid(t *testing.T) {
|
||||||
|
uri, err := url.Parse("https://example.com/1.0/getUser")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
req *x509.CertificateRequest
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
e emailOnlyIdentity
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com"}}}, false},
|
||||||
|
{"DNSNames", "name@smallstep.com", args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, true},
|
||||||
|
{"IPAddresses", "name@smallstep.com", args{&x509.CertificateRequest{IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}}}, true},
|
||||||
|
{"URIs", "name@smallstep.com", args{&x509.CertificateRequest{URIs: []*url.URL{uri}}}, true},
|
||||||
|
{"no-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{}}}, true},
|
||||||
|
{"empty-email", "", args{&x509.CertificateRequest{EmailAddresses: []string{""}}}, true},
|
||||||
|
{"multiple-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com", "foo@smallstep.com"}}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.e.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("emailOnlyIdentity.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_commonNameValidator_Valid(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
req *x509.CertificateRequest
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
v commonNameValidator
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false},
|
||||||
|
{"empty", "", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, true},
|
||||||
|
{"wrong", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("commonNameValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_dnsNamesValidator_Valid(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
req *x509.CertificateRequest
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
v dnsNamesValidator
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok0", []string{}, args{&x509.CertificateRequest{DNSNames: []string{}}}, false},
|
||||||
|
{"ok1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, false},
|
||||||
|
{"ok2", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "bar.zar"}}}, false},
|
||||||
|
{"ok3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, false},
|
||||||
|
{"fail1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar"}}}, true},
|
||||||
|
{"fail2", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, true},
|
||||||
|
{"fail3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "zar.bar"}}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("dnsNamesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_ipAddressesValidator_Valid(t *testing.T) {
|
||||||
|
ip1 := net.IPv4(10, 3, 2, 1)
|
||||||
|
ip2 := net.IPv4(10, 3, 2, 2)
|
||||||
|
ip3 := net.IPv4(10, 3, 2, 3)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
req *x509.CertificateRequest
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
v ipAddressesValidator
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok0", []net.IP{}, args{&x509.CertificateRequest{IPAddresses: []net.IP{}}}, false},
|
||||||
|
{"ok1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1}}}, false},
|
||||||
|
{"ok2", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip2}}}, false},
|
||||||
|
{"ok3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, false},
|
||||||
|
{"fail1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2}}}, true},
|
||||||
|
{"fail2", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, true},
|
||||||
|
{"fail3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip3}}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ipAddressesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validityValidator_Valid(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
min time.Duration
|
||||||
|
max time.Duration
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
crt *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := &validityValidator{
|
||||||
|
min: tt.fields.min,
|
||||||
|
max: tt.fields.max,
|
||||||
|
}
|
||||||
|
if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
11
authority/provisioner/testdata/root_ca.crt
vendored
Normal file
11
authority/provisioner/testdata/root_ca.crt
vendored
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf
|
||||||
|
MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla
|
||||||
|
Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg
|
||||||
|
Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN
|
||||||
|
Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw
|
||||||
|
QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU
|
||||||
|
B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c
|
||||||
|
ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET
|
||||||
|
/A8LXNH4M06A7vE=
|
||||||
|
-----END CERTIFICATE-----
|
272
authority/provisioner/utils_test.go
Normal file
272
authority/provisioner/utils_test.go
Normal file
|
@ -0,0 +1,272 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testAudiences = []string{
|
||||||
|
"https://ca.smallstep.com/sign",
|
||||||
|
"https://ca.smallsteomcom/1.0/sign",
|
||||||
|
}
|
||||||
|
|
||||||
|
func must(args ...interface{}) []interface{} {
|
||||||
|
if l := len(args); l > 0 && args[l-1] != nil {
|
||||||
|
if err, ok := args[l-1].(error); ok {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fp, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jwk.KeyID = string(hex.EncodeToString(fp))
|
||||||
|
return jwk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) {
|
||||||
|
var keySet jose.JSONWebKeySet
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
if err != nil {
|
||||||
|
return jose.JSONWebKeySet{}, err
|
||||||
|
}
|
||||||
|
keySet.Keys = append(keySet.Keys, *key)
|
||||||
|
}
|
||||||
|
return keySet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
|
||||||
|
b, err := json.Marshal(jwk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
opts := new(jose.EncrypterOptions)
|
||||||
|
opts.WithContentType(jose.ContentType("jwk+json"))
|
||||||
|
recipient := jose.Recipient{
|
||||||
|
Algorithm: jose.PBES2_HS256_A128KW,
|
||||||
|
Key: []byte("password"),
|
||||||
|
PBES2Count: jose.PBKDF2Iterations,
|
||||||
|
PBES2Salt: salt,
|
||||||
|
}
|
||||||
|
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return encrypter.Encrypt(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decryptJSONWebKey(key string) (*jose.JSONWebKey, error) {
|
||||||
|
enc, err := jose.ParseEncrypted(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
b, err := enc.Decrypt([]byte("password"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jwk := new(jose.JSONWebKey)
|
||||||
|
if err := json.Unmarshal(b, jwk); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return jwk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateJWK() (*JWK, error) {
|
||||||
|
name, err := randutil.Alphanumeric(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jwk, err := generateJSONWebKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jwe, err := encryptJSONWebKey(jwk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
public := jwk.Public()
|
||||||
|
encrypted, err := jwe.CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &JWK{
|
||||||
|
Name: name,
|
||||||
|
Type: "JWK",
|
||||||
|
Key: &public,
|
||||||
|
EncryptedKey: encrypted,
|
||||||
|
Claims: &globalProvisionerClaims,
|
||||||
|
audiences: testAudiences,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateOIDC() (*OIDC, error) {
|
||||||
|
name, err := randutil.Alphanumeric(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
clientID, err := randutil.Alphanumeric(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
issuer, err := randutil.Alphanumeric(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jwk, err := generateJSONWebKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &OIDC{
|
||||||
|
Name: name,
|
||||||
|
Type: "OIDC",
|
||||||
|
ClientID: clientID,
|
||||||
|
ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration",
|
||||||
|
Claims: &globalProvisionerClaims,
|
||||||
|
configuration: openIDConfiguration{
|
||||||
|
Issuer: issuer,
|
||||||
|
JWKSetURI: "https://example.com/.well-known/jwks",
|
||||||
|
},
|
||||||
|
keyStore: &keyStore{
|
||||||
|
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||||
|
expiry: time.Now().Add(24 * time.Hour),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
|
||||||
|
col := NewCollection(testAudiences)
|
||||||
|
for i := 0; i < nJWK; i++ {
|
||||||
|
p, err := generateJWK()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
col.Store(p)
|
||||||
|
}
|
||||||
|
for i := 0; i < nOIDC; i++ {
|
||||||
|
p, err := generateOIDC()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
col.Store(p)
|
||||||
|
}
|
||||||
|
return col, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
|
||||||
|
return generateToken("subject", iss, aud, "name@smallstep.com", []string{"test.smallstep.com"}, time.Now(), jwk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||||
|
sig, err := jose.NewSigner(
|
||||||
|
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||||
|
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := randutil.ASCII(64)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := struct {
|
||||||
|
jose.Claims
|
||||||
|
Email string `json:"email"`
|
||||||
|
SANS []string `json:"sans"`
|
||||||
|
}{
|
||||||
|
Claims: jose.Claims{
|
||||||
|
ID: id,
|
||||||
|
Subject: sub,
|
||||||
|
Issuer: iss,
|
||||||
|
IssuedAt: jose.NewNumericDate(iat),
|
||||||
|
NotBefore: jose.NewNumericDate(iat),
|
||||||
|
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||||
|
Audience: []string{aud},
|
||||||
|
},
|
||||||
|
Email: email,
|
||||||
|
SANS: sans,
|
||||||
|
}
|
||||||
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
||||||
|
tok, err := jose.ParseSigned(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
claims := new(jose.Claims)
|
||||||
|
if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return tok, claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateJWKServer(n int) *httptest.Server {
|
||||||
|
hits := struct {
|
||||||
|
Hits int `json:"hits"`
|
||||||
|
}{}
|
||||||
|
writeJSON := func(w http.ResponseWriter, v interface{}) {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Add("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(b)
|
||||||
|
}
|
||||||
|
getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
|
||||||
|
var ret jose.JSONWebKeySet
|
||||||
|
for _, k := range ks.Keys {
|
||||||
|
ret.Keys = append(ret.Keys, k.Public())
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||||
|
srv := httptest.NewUnstartedServer(nil)
|
||||||
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hits.Hits++
|
||||||
|
switch r.RequestURI {
|
||||||
|
case "/error":
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||||
|
case "/hits":
|
||||||
|
writeJSON(w, hits)
|
||||||
|
case "/openid-configuration", "/.well-known/openid-configuration":
|
||||||
|
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
|
||||||
|
case "/random":
|
||||||
|
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||||
|
w.Header().Add("Cache-Control", "max-age=5")
|
||||||
|
writeJSON(w, getPublic(keySet))
|
||||||
|
case "/private":
|
||||||
|
writeJSON(w, defaultKeySet)
|
||||||
|
default:
|
||||||
|
w.Header().Add("Cache-Control", "max-age=5")
|
||||||
|
writeJSON(w, getPublic(defaultKeySet))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
srv.Start()
|
||||||
|
return srv
|
||||||
|
}
|
|
@ -1,55 +0,0 @@
|
||||||
package authority
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestProvisionerInit(t *testing.T) {
|
|
||||||
type ProvisionerValidateTest struct {
|
|
||||||
p *Provisioner
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
|
||||||
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
|
||||||
return ProvisionerValidateTest{
|
|
||||||
p: &Provisioner{},
|
|
||||||
err: errors.New("provisioner name cannot be empty"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
|
||||||
return ProvisionerValidateTest{
|
|
||||||
p: &Provisioner{Name: "foo"},
|
|
||||||
err: errors.New("provisioner type cannot be empty"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail-empty-key": func(t *testing.T) ProvisionerValidateTest {
|
|
||||||
return ProvisionerValidateTest{
|
|
||||||
p: &Provisioner{Name: "foo", Type: "bar"},
|
|
||||||
err: errors.New("provisioner key cannot be empty"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
|
||||||
return ProvisionerValidateTest{
|
|
||||||
p: &Provisioner{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, get := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := get(t)
|
|
||||||
err := tc.p.Init(&globalProvisionerClaims)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.err.Error(), err.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, tc.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,115 +1,25 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
|
||||||
const DefaultProvisionersLimit = 20
|
|
||||||
|
|
||||||
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
|
||||||
const DefaultProvisionersMax = 100
|
|
||||||
|
|
||||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||||
val, ok := a.encryptedKeyIndex.Load(kid)
|
key, ok := a.provisioners.LoadEncryptedKey(kid)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
||||||
http.StatusNotFound, context{}}
|
http.StatusNotFound, context{}}
|
||||||
}
|
}
|
||||||
|
|
||||||
key, ok := val.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", &apiError{errors.Errorf("stored value is not a string"),
|
|
||||||
http.StatusInternalServerError, context{}}
|
|
||||||
}
|
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProvisioners returns a map listing each provisioner and the JWK Key Set
|
// GetProvisioners returns a map listing each provisioner and the JWK Key Set
|
||||||
// with their public keys.
|
// with their public keys.
|
||||||
func (a *Authority) GetProvisioners(cursor string, limit int) ([]*Provisioner, string, error) {
|
func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) {
|
||||||
provisioners, nextCursor := a.sortedProvisioners.Find(cursor, limit)
|
provisioners, nextCursor := a.provisioners.Find(cursor, limit)
|
||||||
return provisioners, nextCursor, nil
|
return provisioners, nextCursor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type uidProvisioner struct {
|
|
||||||
provisioner *Provisioner
|
|
||||||
uid string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSortedProvisioners(provisioners []*Provisioner) (provisionerSlice, error) {
|
|
||||||
if len(provisioners) > math.MaxInt32 {
|
|
||||||
return nil, errors.New("too many provisioners")
|
|
||||||
}
|
|
||||||
|
|
||||||
var slice provisionerSlice
|
|
||||||
bi := make([]byte, 4)
|
|
||||||
for i, p := range provisioners {
|
|
||||||
sum, err := provisionerSum(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Use the first 4 bytes (32bit) of the sum to insert the order
|
|
||||||
// Using big endian format to get the strings sorted:
|
|
||||||
// 0x00000000, 0x00000001, 0x00000002, ...
|
|
||||||
binary.BigEndian.PutUint32(bi, uint32(i))
|
|
||||||
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
|
||||||
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
|
|
||||||
slice = append(slice, uidProvisioner{
|
|
||||||
provisioner: p,
|
|
||||||
uid: hex.EncodeToString(sum),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
sort.Sort(slice)
|
|
||||||
return slice, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type provisionerSlice []uidProvisioner
|
|
||||||
|
|
||||||
func (p provisionerSlice) Len() int { return len(p) }
|
|
||||||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
|
||||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
|
||||||
|
|
||||||
func (p provisionerSlice) Find(cursor string, limit int) ([]*Provisioner, string) {
|
|
||||||
switch {
|
|
||||||
case limit <= 0:
|
|
||||||
limit = DefaultProvisionersLimit
|
|
||||||
case limit > DefaultProvisionersMax:
|
|
||||||
limit = DefaultProvisionersMax
|
|
||||||
}
|
|
||||||
|
|
||||||
n := len(p)
|
|
||||||
cursor = fmt.Sprintf("%040s", cursor)
|
|
||||||
i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor })
|
|
||||||
|
|
||||||
var slice []*Provisioner
|
|
||||||
for ; i < n && len(slice) < limit; i++ {
|
|
||||||
slice = append(slice, p[i].provisioner)
|
|
||||||
}
|
|
||||||
if i < n {
|
|
||||||
return slice, strings.TrimLeft(p[i].uid, "0")
|
|
||||||
}
|
|
||||||
return slice, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// provisionerSum returns the SHA1 of the json representation of the
|
|
||||||
// provisioner. From this we will create the unique and sorted id.
|
|
||||||
func provisionerSum(p *Provisioner) ([]byte, error) {
|
|
||||||
b, err := json.Marshal(p.Key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error marshalling provisioner")
|
|
||||||
}
|
|
||||||
sum := sha1.Sum(b)
|
|
||||||
return sum[:], nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,16 +1,12 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/cli/crypto/randutil"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/cli/jose"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetEncryptedKey(t *testing.T) {
|
func TestGetEncryptedKey(t *testing.T) {
|
||||||
|
@ -27,7 +23,7 @@ func TestGetEncryptedKey(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return &ek{
|
return &ek{
|
||||||
a: a,
|
a: a,
|
||||||
kid: c.AuthorityConfig.Provisioners[1].Key.KeyID,
|
kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail-not-found": func(t *testing.T) *ek {
|
"fail-not-found": func(t *testing.T) *ek {
|
||||||
|
@ -42,19 +38,6 @@ func TestGetEncryptedKey(t *testing.T) {
|
||||||
http.StatusNotFound, context{}},
|
http.StatusNotFound, context{}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail-invalid-type-found": func(t *testing.T) *ek {
|
|
||||||
c, err := LoadConfiguration("../ca/testdata/ca.json")
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
a, err := New(c)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
a.encryptedKeyIndex.Store("foo", 5)
|
|
||||||
return &ek{
|
|
||||||
a: a,
|
|
||||||
kid: "foo",
|
|
||||||
err: &apiError{errors.Errorf("stored value is not a string"),
|
|
||||||
http.StatusInternalServerError, context{}},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, genTestCase := range tests {
|
for name, genTestCase := range tests {
|
||||||
|
@ -75,9 +58,9 @@ func TestGetEncryptedKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if assert.Nil(t, tc.err) {
|
if assert.Nil(t, tc.err) {
|
||||||
val, ok := tc.a.provisionerIDIndex.Load("max:" + tc.kid)
|
val, ok := tc.a.provisioners.Load("max:" + tc.kid)
|
||||||
assert.Fatal(t, ok)
|
assert.Fatal(t, ok)
|
||||||
p, ok := val.(*Provisioner)
|
p, ok := val.(*provisioner.JWK)
|
||||||
assert.Fatal(t, ok)
|
assert.Fatal(t, ok)
|
||||||
assert.Equals(t, p.EncryptedKey, ek)
|
assert.Equals(t, p.EncryptedKey, ek)
|
||||||
}
|
}
|
||||||
|
@ -126,102 +109,3 @@ func TestGetProvisioners(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateProvisioner(t *testing.T) *Provisioner {
|
|
||||||
name, err := randutil.Alphanumeric(10)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
// Create a new JWK
|
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
// Encrypt JWK
|
|
||||||
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(jwk)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
recipient := jose.Recipient{
|
|
||||||
Algorithm: jose.PBES2_HS256_A128KW,
|
|
||||||
Key: []byte("password"),
|
|
||||||
PBES2Count: jose.PBKDF2Iterations,
|
|
||||||
PBES2Salt: salt,
|
|
||||||
}
|
|
||||||
opts := new(jose.EncrypterOptions)
|
|
||||||
opts.WithContentType(jose.ContentType("jwk+json"))
|
|
||||||
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
jwe, err := encrypter.Encrypt(b)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
// get public and encrypted keys
|
|
||||||
public := jwk.Public()
|
|
||||||
encrypted, err := jwe.CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return &Provisioner{
|
|
||||||
Name: name,
|
|
||||||
Type: "JWT",
|
|
||||||
Key: &public,
|
|
||||||
EncryptedKey: encrypted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_newSortedProvisioners(t *testing.T) {
|
|
||||||
provisioners := make([]*Provisioner, 20)
|
|
||||||
for i := range provisioners {
|
|
||||||
provisioners[i] = generateProvisioner(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
ps, err := newSortedProvisioners(provisioners)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
prev := ""
|
|
||||||
for i, p := range ps {
|
|
||||||
if p.uid < prev {
|
|
||||||
t.Errorf("%s should be less that %s", p.uid, prev)
|
|
||||||
}
|
|
||||||
if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
|
|
||||||
t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
|
|
||||||
}
|
|
||||||
prev = p.uid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_provisionerSlice_Find(t *testing.T) {
|
|
||||||
trim := func(s string) string {
|
|
||||||
return strings.TrimLeft(s, "0")
|
|
||||||
}
|
|
||||||
provisioners := make([]*Provisioner, 20)
|
|
||||||
for i := range provisioners {
|
|
||||||
provisioners[i] = generateProvisioner(t)
|
|
||||||
}
|
|
||||||
ps, err := newSortedProvisioners(provisioners)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
cursor string
|
|
||||||
limit int
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
p provisionerSlice
|
|
||||||
args args
|
|
||||||
want []*Provisioner
|
|
||||||
want1 string
|
|
||||||
}{
|
|
||||||
{"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
|
|
||||||
{"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
|
|
||||||
{"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
|
|
||||||
{"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
|
|
||||||
{"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
|
|
||||||
{"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
|
|
||||||
{"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
|
|
||||||
{"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
if got1 != tt.want1 {
|
|
||||||
t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ package authority
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -11,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/cli/crypto/pemutil"
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
"github.com/smallstep/cli/crypto/tlsutil"
|
"github.com/smallstep/cli/crypto/tlsutil"
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
|
@ -22,48 +22,7 @@ func (a *Authority) GetTLSOptions() *tlsutil.TLSOptions {
|
||||||
return a.config.TLS
|
return a.config.TLS
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOptions contains the options that can be passed to the Authority.Sign
|
var oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35}
|
||||||
// method.
|
|
||||||
type SignOptions struct {
|
|
||||||
NotAfter time.Time `json:"notAfter"`
|
|
||||||
NotBefore time.Time `json:"notBefore"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
|
||||||
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
|
||||||
oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35}
|
|
||||||
)
|
|
||||||
|
|
||||||
type stepProvisionerASN1 struct {
|
|
||||||
Type int
|
|
||||||
Name []byte
|
|
||||||
CredentialID []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
const provisionerTypeJWK = 1
|
|
||||||
|
|
||||||
func withProvisionerOID(name, kid string) x509util.WithOption {
|
|
||||||
return func(p x509util.Profile) error {
|
|
||||||
crt := p.Subject()
|
|
||||||
|
|
||||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
|
||||||
Type: provisionerTypeJWK,
|
|
||||||
Name: []byte(name),
|
|
||||||
CredentialID: []byte(kid),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{
|
|
||||||
Id: stepOIDProvisioner,
|
|
||||||
Critical: false,
|
|
||||||
Value: b,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
||||||
return func(p x509util.Profile) error {
|
return func(p x509util.Profile) error {
|
||||||
|
@ -96,28 +55,22 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sign creates a signed certificate from a certificate signing request.
|
// Sign creates a signed certificate from a certificate signing request.
|
||||||
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) {
|
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
|
||||||
var (
|
var (
|
||||||
errContext = context{"csr": csr, "signOptions": signOpts}
|
errContext = context{"csr": csr, "signOptions": signOpts}
|
||||||
claims = []certClaim{}
|
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
|
||||||
mods = []x509util.WithOption{}
|
certValidators = []provisioner.CertificateValidator{}
|
||||||
)
|
)
|
||||||
for _, op := range extraOpts {
|
for _, op := range extraOpts {
|
||||||
switch k := op.(type) {
|
switch k := op.(type) {
|
||||||
case certClaim:
|
case provisioner.CertificateValidator:
|
||||||
claims = append(claims, k)
|
certValidators = append(certValidators, k)
|
||||||
case x509util.WithOption:
|
case provisioner.CertificateRequestValidator:
|
||||||
mods = append(mods, k)
|
if err := k.Valid(csr); err != nil {
|
||||||
case *Provisioner:
|
return nil, nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
|
||||||
m, c, err := k.getTLSApps(signOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, &apiError{err, http.StatusInternalServerError, errContext}
|
|
||||||
}
|
}
|
||||||
mods = append(mods, m...)
|
case provisioner.ProfileModifier:
|
||||||
mods = append(mods, []x509util.WithOption{
|
mods = append(mods, k.Option(signOpts))
|
||||||
withDefaultASN1DN(a.config.AuthorityConfig.Template),
|
|
||||||
}...)
|
|
||||||
claims = append(claims, c...)
|
|
||||||
default:
|
default:
|
||||||
return nil, nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k),
|
return nil, nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k),
|
||||||
http.StatusInternalServerError, errContext}
|
http.StatusInternalServerError, errContext}
|
||||||
|
@ -137,10 +90,6 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, ext
|
||||||
return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext}
|
return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateClaims(leaf.Subject(), claims); err != nil {
|
|
||||||
return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusUnauthorized, errContext}
|
|
||||||
}
|
|
||||||
|
|
||||||
crtBytes, err := leaf.CreateCertificate()
|
crtBytes, err := leaf.CreateCertificate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"),
|
return nil, nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"),
|
||||||
|
@ -153,6 +102,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, ext
|
||||||
http.StatusInternalServerError, errContext}
|
http.StatusInternalServerError, errContext}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME: This should be before creating the certificate.
|
||||||
|
for _, v := range certValidators {
|
||||||
|
if err := v.Valid(serverCert); err != nil {
|
||||||
|
return nil, nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
|
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"),
|
return nil, nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"),
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -15,12 +14,49 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/cli/crypto/keys"
|
"github.com/smallstep/cli/crypto/keys"
|
||||||
"github.com/smallstep/cli/crypto/tlsutil"
|
"github.com/smallstep/cli/crypto/tlsutil"
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
stepx509 "github.com/smallstep/cli/pkg/x509"
|
stepx509 "github.com/smallstep/cli/pkg/x509"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||||
|
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
||||||
|
)
|
||||||
|
|
||||||
|
const provisionerTypeJWK = 1
|
||||||
|
|
||||||
|
type stepProvisionerASN1 struct {
|
||||||
|
Type int
|
||||||
|
Name []byte
|
||||||
|
CredentialID []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func withProvisionerOID(name, kid string) x509util.WithOption {
|
||||||
|
return func(p x509util.Profile) error {
|
||||||
|
crt := p.Subject()
|
||||||
|
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||||
|
Type: provisionerTypeJWK,
|
||||||
|
Name: []byte(name),
|
||||||
|
CredentialID: []byte(kid),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: stepOIDProvisioner,
|
||||||
|
Critical: false,
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateRequest)) *x509.CertificateRequest {
|
func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateRequest)) *x509.CertificateRequest {
|
||||||
_csr := &x509.CertificateRequest{
|
_csr := &x509.CertificateRequest{
|
||||||
Subject: pkix.Name{CommonName: "smallstep test"},
|
Subject: pkix.Name{CommonName: "smallstep test"},
|
||||||
|
@ -52,24 +88,25 @@ func TestSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
nb := time.Now()
|
nb := time.Now()
|
||||||
signOpts := SignOptions{
|
signOpts := provisioner.Options{
|
||||||
NotBefore: nb,
|
NotBefore: nb,
|
||||||
NotAfter: nb.Add(time.Minute * 5),
|
NotAfter: nb.Add(time.Minute * 5),
|
||||||
}
|
}
|
||||||
|
|
||||||
p := a.config.AuthorityConfig.Provisioners[1]
|
// Create a token to get test extra opts.
|
||||||
extraOpts := []interface{}{
|
p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK)
|
||||||
&commonNameClaim{"smallstep test"},
|
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||||
&dnsNamesClaim{[]string{"test.smallstep.com"}},
|
assert.FatalError(t, err)
|
||||||
&ipAddressesClaim{[]net.IP{}},
|
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
|
||||||
p,
|
assert.FatalError(t, err)
|
||||||
}
|
extraOpts, err := a.Authorize(token)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type signTest struct {
|
type signTest struct {
|
||||||
auth *Authority
|
auth *Authority
|
||||||
csr *x509.CertificateRequest
|
csr *x509.CertificateRequest
|
||||||
signOpts SignOptions
|
signOpts provisioner.Options
|
||||||
extraOpts []interface{}
|
extraOpts []provisioner.SignOption
|
||||||
err *apiError
|
err *apiError
|
||||||
}
|
}
|
||||||
tests := map[string]func(*testing.T) *signTest{
|
tests := map[string]func(*testing.T) *signTest{
|
||||||
|
@ -123,7 +160,7 @@ func TestSign(t *testing.T) {
|
||||||
return &signTest{
|
return &signTest{
|
||||||
auth: _a,
|
auth: _a,
|
||||||
csr: csr,
|
csr: csr,
|
||||||
extraOpts: []interface{}{p},
|
extraOpts: extraOpts,
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: error creating new leaf certificate"),
|
err: &apiError{errors.New("sign: error creating new leaf certificate"),
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
|
@ -133,7 +170,7 @@ func TestSign(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail provisioner duration claim": func(t *testing.T) *signTest {
|
"fail provisioner duration claim": func(t *testing.T) *signTest {
|
||||||
csr := getCSR(t, priv)
|
csr := getCSR(t, priv)
|
||||||
_signOpts := SignOptions{
|
_signOpts := provisioner.Options{
|
||||||
NotBefore: nb,
|
NotBefore: nb,
|
||||||
NotAfter: nb.Add(time.Hour * 25),
|
NotAfter: nb.Add(time.Hour * 25),
|
||||||
}
|
}
|
||||||
|
@ -157,7 +194,7 @@ func TestSign(t *testing.T) {
|
||||||
csr: csr,
|
csr: csr,
|
||||||
extraOpts: extraOpts,
|
extraOpts: extraOpts,
|
||||||
signOpts: signOpts,
|
signOpts: signOpts,
|
||||||
err: &apiError{errors.New("sign: DNS names claim failed - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
||||||
http.StatusUnauthorized,
|
http.StatusUnauthorized,
|
||||||
context{"csr": csr, "signOptions": signOpts},
|
context{"csr": csr, "signOptions": signOpts},
|
||||||
},
|
},
|
||||||
|
@ -262,7 +299,7 @@ func TestRenew(t *testing.T) {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
nb1 := now.Add(-time.Minute * 7)
|
nb1 := now.Add(-time.Minute * 7)
|
||||||
na1 := now
|
na1 := now
|
||||||
so := &SignOptions{
|
so := &provisioner.Options{
|
||||||
NotBefore: nb1,
|
NotBefore: nb1,
|
||||||
NotAfter: na1,
|
NotAfter: na1,
|
||||||
}
|
}
|
||||||
|
@ -272,7 +309,7 @@ func TestRenew(t *testing.T) {
|
||||||
x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0),
|
x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0),
|
||||||
withDefaultASN1DN(a.config.AuthorityConfig.Template),
|
withDefaultASN1DN(a.config.AuthorityConfig.Template),
|
||||||
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
|
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
|
||||||
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].Key.KeyID))
|
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID))
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
crtBytes, err := leaf.CreateCertificate()
|
crtBytes, err := leaf.CreateCertificate()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
@ -284,7 +321,7 @@ func TestRenew(t *testing.T) {
|
||||||
x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0),
|
x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0),
|
||||||
withDefaultASN1DN(a.config.AuthorityConfig.Template),
|
withDefaultASN1DN(a.config.AuthorityConfig.Template),
|
||||||
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
|
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
|
||||||
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID),
|
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID),
|
||||||
)
|
)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
crtBytesNoRenew, err := leafNoRenew.CreateCertificate()
|
crtBytesNoRenew, err := leafNoRenew.CreateCertificate()
|
||||||
|
@ -321,7 +358,7 @@ func TestRenew(t *testing.T) {
|
||||||
}
|
}
|
||||||
return &renewTest{
|
return &renewTest{
|
||||||
crt: crtNoRenew,
|
crt: crtNoRenew,
|
||||||
err: &apiError{errors.New("renew disabled"),
|
err: &apiError{errors.New("renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
|
||||||
http.StatusUnauthorized, ctx},
|
http.StatusUnauthorized, ctx},
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
|
|
@ -2,48 +2,10 @@ package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
|
|
||||||
type Duration struct {
|
|
||||||
time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalJSON parses a duration string and sets it to the duration.
|
|
||||||
//
|
|
||||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
|
||||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
|
||||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
|
||||||
func (d *Duration) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(d.Duration.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
|
||||||
//
|
|
||||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
|
||||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
|
||||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
|
||||||
func (d *Duration) UnmarshalJSON(data []byte) (err error) {
|
|
||||||
var (
|
|
||||||
s string
|
|
||||||
_d time.Duration
|
|
||||||
)
|
|
||||||
if d == nil {
|
|
||||||
return errors.New("duration cannot be nil")
|
|
||||||
}
|
|
||||||
if err = json.Unmarshal(data, &s); err != nil {
|
|
||||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
|
||||||
}
|
|
||||||
if _d, err = time.ParseDuration(s); err != nil {
|
|
||||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
|
||||||
}
|
|
||||||
d.Duration = _d
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// multiString represents a type that can be encoded/decoded in JSON as a single
|
// multiString represents a type that can be encoded/decoded in JSON as a single
|
||||||
// string or an array of strings.
|
// string or an array of strings.
|
||||||
type multiString []string
|
type multiString []string
|
||||||
|
|
|
@ -3,7 +3,6 @@ package authority
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_multiString_First(t *testing.T) {
|
func Test_multiString_First(t *testing.T) {
|
||||||
|
@ -101,57 +100,3 @@ func Test_multiString_UnmarshalJSON(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDuration_UnmarshalJSON(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
data []byte
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
d *Duration
|
|
||||||
args args
|
|
||||||
want *Duration
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"empty", new(Duration), args{[]byte{}}, new(Duration), true},
|
|
||||||
{"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true},
|
|
||||||
{"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true},
|
|
||||||
{"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true},
|
|
||||||
{"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false},
|
|
||||||
{"nil", nil, args{nil}, nil, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(tt.d, tt.want) {
|
|
||||||
t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_duration_MarshalJSON(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
d *Duration
|
|
||||||
want []byte
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := tt.d.MarshalJSON()
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/authority"
|
provisioners "github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/ca"
|
"github.com/smallstep/certificates/ca"
|
||||||
"github.com/smallstep/cli/config"
|
"github.com/smallstep/cli/config"
|
||||||
"github.com/smallstep/cli/crypto/randutil"
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
@ -111,13 +111,15 @@ func loadProvisionerJWKByName(name, caURL, caRoot, passFile string) (key *jose.J
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, provisioner := range provisioners {
|
for _, provisioner := range provisioners {
|
||||||
if provisioner.Name == name {
|
if provisioner.GetName() == name {
|
||||||
key, err = decryptProvisionerJWK(provisioner.EncryptedKey, passFile)
|
if _, encryptedKey, ok := provisioner.GetEncryptedKey(); ok {
|
||||||
|
key, err = decryptProvisionerJWK(encryptedKey, passFile)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name)
|
return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +156,7 @@ func getRootCAPath() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// getProvisioners returns the map of provisioners on the given CA.
|
// getProvisioners returns the map of provisioners on the given CA.
|
||||||
func getProvisioners(caURL, rootFile string) ([]*authority.Provisioner, error) {
|
func getProvisioners(caURL, rootFile string) (provisioners.List, error) {
|
||||||
if len(rootFile) == 0 {
|
if len(rootFile) == 0 {
|
||||||
rootFile = getRootCAPath()
|
rootFile = getRootCAPath()
|
||||||
}
|
}
|
||||||
|
@ -163,7 +165,7 @@ func getProvisioners(caURL, rootFile string) ([]*authority.Provisioner, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
cursor := ""
|
cursor := ""
|
||||||
provisioners := []*authority.Provisioner{}
|
var provisioners provisioners.List
|
||||||
for {
|
for {
|
||||||
resp, err := client.Provisioners(ca.WithProvisionerCursor(cursor), ca.WithProvisionerLimit(100))
|
resp, err := client.Provisioners(ca.WithProvisionerCursor(cursor), ca.WithProvisionerLimit(100))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/cli/crypto/keys"
|
"github.com/smallstep/cli/crypto/keys"
|
||||||
"github.com/smallstep/cli/crypto/pemutil"
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
"github.com/smallstep/cli/crypto/randutil"
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
@ -389,7 +390,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) *ekt {
|
"ok": func(t *testing.T) *ekt {
|
||||||
p := config.AuthorityConfig.Provisioners[2]
|
p := config.AuthorityConfig.Provisioners[2].(*provisioner.JWK)
|
||||||
return &ekt{
|
return &ekt{
|
||||||
ca: ca,
|
ca: ca,
|
||||||
kid: p.Key.KeyID,
|
kid: p.Key.KeyID,
|
||||||
|
|
|
@ -446,7 +446,11 @@ func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error)
|
||||||
return nil, nil, errors.Wrap(err, "error generating key")
|
return nil, nil, errors.Wrap(err, "error generating key")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var emails []string
|
||||||
dnsNames, ips := x509util.SplitSANs(claims.SANs)
|
dnsNames, ips := x509util.SplitSANs(claims.SANs)
|
||||||
|
if claims.Email != "" {
|
||||||
|
emails = append(emails, claims.Email)
|
||||||
|
}
|
||||||
|
|
||||||
template := &x509.CertificateRequest{
|
template := &x509.CertificateRequest{
|
||||||
Subject: pkix.Name{
|
Subject: pkix.Name{
|
||||||
|
@ -455,6 +459,7 @@ func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error)
|
||||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||||
DNSNames: dnsNames,
|
DNSNames: dnsNames,
|
||||||
IPAddresses: ips,
|
IPAddresses: ips,
|
||||||
|
EmailAddresses: emails,
|
||||||
}
|
}
|
||||||
|
|
||||||
csr, err := x509.CreateCertificateRequest(rand.Reader, template, pk)
|
csr, err := x509.CreateCertificateRequest(rand.Reader, template, pk)
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -391,7 +391,7 @@ func TestClient_Renew(t *testing.T) {
|
||||||
|
|
||||||
func TestClient_Provisioners(t *testing.T) {
|
func TestClient_Provisioners(t *testing.T) {
|
||||||
ok := &api.ProvisionersResponse{
|
ok := &api.ProvisionersResponse{
|
||||||
Provisioners: []*authority.Provisioner{},
|
Provisioners: provisioner.List{},
|
||||||
}
|
}
|
||||||
internalServerError := api.InternalServerError(fmt.Errorf("Internal Server Error"))
|
internalServerError := api.InternalServerError(fmt.Errorf("Internal Server Error"))
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
@ -143,6 +144,25 @@ intermediate private key.`,
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Action = func(ctx *cli.Context) error {
|
app.Action = func(ctx *cli.Context) error {
|
||||||
|
// Hack to be able to run a the top action as a subcommand
|
||||||
|
cmd := cli.Command{Name: "start", Action: startAction, Flags: app.Flags}
|
||||||
|
set := flag.NewFlagSet(app.Name, flag.ContinueOnError)
|
||||||
|
set.Parse(os.Args)
|
||||||
|
ctx = cli.NewContext(app, set, nil)
|
||||||
|
return cmd.Run(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := app.Run(os.Args); err != nil {
|
||||||
|
if os.Getenv("STEPDEBUG") == "1" {
|
||||||
|
fmt.Fprintf(os.Stderr, "%+v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startAction(ctx *cli.Context) error {
|
||||||
passFile := ctx.String("password-file")
|
passFile := ctx.String("password-file")
|
||||||
|
|
||||||
// If zero cmd line args show help, if >1 cmd line args show error.
|
// If zero cmd line args show help, if >1 cmd line args show error.
|
||||||
|
@ -179,16 +199,6 @@ intermediate private key.`,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := app.Run(os.Args); err != nil {
|
|
||||||
if os.Getenv("STEPDEBUG") == "1" {
|
|
||||||
fmt.Fprintf(os.Stderr, "%+v\n", err)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintln(os.Stderr, err)
|
|
||||||
}
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fatal writes the passed error on the standard error and exits with the exit
|
// fatal writes the passed error on the standard error and exits with the exit
|
||||||
// code 1. If the environment variable STEPDEBUG is set to 1 it shows the
|
// code 1. If the environment variable STEPDEBUG is set to 1 it shows the
|
||||||
// stack trace of the error.
|
// stack trace of the error.
|
||||||
|
|
Loading…
Reference in a new issue