Merge pull request #51 from smallstep/oidc-provisioner
OIDC provisioners
This commit is contained in:
commit
095ab891e7
40 changed files with 3080 additions and 1234 deletions
20
api/api.go
20
api/api.go
|
@ -18,19 +18,19 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
)
|
||||
|
||||
// Authority is the interface implemented by a CA authority.
|
||||
type Authority interface {
|
||||
Authorize(ott string) ([]interface{}, error)
|
||||
Authorize(ott string) ([]provisioner.SignOption, error)
|
||||
GetTLSOptions() *tlsutil.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
|
||||
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
||||
GetEncryptedKey(kid string) (string, error)
|
||||
GetRoots() (federation []*x509.Certificate, err error)
|
||||
GetFederation() ([]*x509.Certificate, error)
|
||||
|
@ -161,11 +161,11 @@ type SignRequest struct {
|
|||
// ProvisionersResponse is the response object that returns the list of
|
||||
// provisioners.
|
||||
type ProvisionersResponse struct {
|
||||
Provisioners []*authority.Provisioner `json:"provisioners"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
Provisioners provisioner.List `json:"provisioners"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}
|
||||
|
||||
// ProvisionerKeyResponse is the response object that returns the encryptoed key
|
||||
// ProvisionerKeyResponse is the response object that returns the encrypted key
|
||||
// of a provisioner.
|
||||
type ProvisionerKeyResponse struct {
|
||||
Key string `json:"key"`
|
||||
|
@ -266,18 +266,18 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
signOpts := authority.SignOptions{
|
||||
opts := provisioner.Options{
|
||||
NotBefore: body.NotBefore,
|
||||
NotAfter: body.NotAfter,
|
||||
}
|
||||
|
||||
extraOpts, err := h.Authority.Authorize(body.OTT)
|
||||
signOpts, err := h.Authority.Authorize(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, Unauthorized(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, signOpts, extraOpts...)
|
||||
cert, root, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
|
|
|
@ -24,7 +24,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
|
@ -410,22 +410,22 @@ func TestSignRequest_Validate(t *testing.T) {
|
|||
type mockAuthority struct {
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
authorize func(ott string) ([]interface{}, error)
|
||||
authorize func(ott string) ([]provisioner.SignOption, error)
|
||||
getTLSOptions func() *tlsutil.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
sign func(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
|
||||
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
getEncryptedKey func(kid string) (string, error)
|
||||
getRoots func() ([]*x509.Certificate, error)
|
||||
getFederation func() ([]*x509.Certificate, error)
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
|
||||
func (m *mockAuthority) Authorize(ott string) ([]provisioner.SignOption, error) {
|
||||
if m.authorize != nil {
|
||||
return m.authorize(ott)
|
||||
}
|
||||
return m.ret1.([]interface{}), m.err
|
||||
return m.ret1.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetTLSOptions() *tlsutil.TLSOptions {
|
||||
|
@ -442,9 +442,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
|
|||
return m.ret1.(*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) {
|
||||
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
|
||||
if m.sign != nil {
|
||||
return m.sign(cr, signOpts, extraOpts...)
|
||||
return m.sign(cr, opts, signOpts...)
|
||||
}
|
||||
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
||||
}
|
||||
|
@ -456,11 +456,11 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.
|
|||
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) ([]*authority.Provisioner, string, error) {
|
||||
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
|
||||
if m.getProvisioners != nil {
|
||||
return m.getProvisioners(nextCursor, limit)
|
||||
}
|
||||
return m.ret1.([]*authority.Provisioner), m.ret2.(string), m.err
|
||||
return m.ret1.(provisioner.List), m.ret2.(string), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
|
||||
|
@ -597,7 +597,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
certAttrOpts []interface{}
|
||||
certAttrOpts []provisioner.SignOption
|
||||
autherr error
|
||||
cert *x509.Certificate
|
||||
root *x509.Certificate
|
||||
|
@ -617,7 +617,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||||
authorize: func(ott string) ([]interface{}, error) {
|
||||
authorize: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return tt.certAttrOpts, tt.autherr
|
||||
},
|
||||
getTLSOptions: func() *tlsutil.TLSOptions {
|
||||
|
@ -723,14 +723,14 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p := []*authority.Provisioner{
|
||||
{
|
||||
p := provisioner.List{
|
||||
&provisioner.JWK{
|
||||
Type: "JWK",
|
||||
Name: "max",
|
||||
EncryptedKey: "abc",
|
||||
Key: &key,
|
||||
},
|
||||
{
|
||||
&provisioner.JWK{
|
||||
Type: "JWK",
|
||||
Name: "mariano",
|
||||
EncryptedKey: "def",
|
||||
|
|
|
@ -4,10 +4,10 @@ import (
|
|||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
)
|
||||
|
@ -16,18 +16,14 @@ const legacyAuthority = "step-certificate-authority"
|
|||
|
||||
// Authority implements the Certificate Authority internal interface.
|
||||
type Authority struct {
|
||||
config *Config
|
||||
rootX509Certs []*x509.Certificate
|
||||
intermediateIdentity *x509util.Identity
|
||||
validateOnce bool
|
||||
certificates *sync.Map
|
||||
ottMap *sync.Map
|
||||
startTime time.Time
|
||||
provisionerIDIndex *sync.Map
|
||||
encryptedKeyIndex *sync.Map
|
||||
provisionerKeySetIndex *sync.Map
|
||||
sortedProvisioners provisionerSlice
|
||||
audiences []string
|
||||
config *Config
|
||||
rootX509Certs []*x509.Certificate
|
||||
intermediateIdentity *x509util.Identity
|
||||
validateOnce bool
|
||||
certificates *sync.Map
|
||||
ottMap *sync.Map
|
||||
startTime time.Time
|
||||
provisioners *provisioner.Collection
|
||||
// Do not re-initialize
|
||||
initOnce bool
|
||||
}
|
||||
|
@ -39,31 +35,11 @@ func New(config *Config) (*Authority, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Get sorted provisioners
|
||||
var sorted provisionerSlice
|
||||
if config.AuthorityConfig != nil {
|
||||
sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Define audiences: legacy + possible urls without the ports.
|
||||
// The CA might have proxies in front so we cannot rely on the port.
|
||||
audiences := []string{legacyAuthority}
|
||||
for _, name := range config.DNSNames {
|
||||
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
|
||||
}
|
||||
|
||||
var a = &Authority{
|
||||
config: config,
|
||||
certificates: new(sync.Map),
|
||||
ottMap: new(sync.Map),
|
||||
provisionerIDIndex: new(sync.Map),
|
||||
encryptedKeyIndex: new(sync.Map),
|
||||
provisionerKeySetIndex: new(sync.Map),
|
||||
sortedProvisioners: sorted,
|
||||
audiences: audiences,
|
||||
config: config,
|
||||
certificates: new(sync.Map),
|
||||
ottMap: new(sync.Map),
|
||||
provisioners: provisioner.NewCollection(config.getAudiences()),
|
||||
}
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
|
@ -120,14 +96,15 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Store all the provisioners
|
||||
for _, p := range a.config.AuthorityConfig.Provisioners {
|
||||
a.provisionerIDIndex.Store(p.ID(), p)
|
||||
if len(p.EncryptedKey) != 0 {
|
||||
a.encryptedKeyIndex.Store(p.Key.KeyID, p.EncryptedKey)
|
||||
if err := a.provisioners.Store(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
a.startTime = time.Now()
|
||||
// JWT numeric dates are seconds.
|
||||
a.startTime = time.Now().Truncate(time.Second)
|
||||
// Set flag indicating that initialization has been completed, and should
|
||||
// not be repeated.
|
||||
a.initOnce = true
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
stepJOSE "github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
|
@ -16,22 +17,22 @@ func testAuthority(t *testing.T) *Authority {
|
|||
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
disableRenewal := true
|
||||
p := []*Provisioner{
|
||||
{
|
||||
p := provisioner.List{
|
||||
&provisioner.JWK{
|
||||
Name: "Max",
|
||||
Type: "JWK",
|
||||
Key: maxjwk,
|
||||
},
|
||||
{
|
||||
&provisioner.JWK{
|
||||
Name: "step-cli",
|
||||
Type: "JWK",
|
||||
Key: clijwk,
|
||||
},
|
||||
{
|
||||
&provisioner.JWK{
|
||||
Name: "dev",
|
||||
Type: "JWK",
|
||||
Key: maxjwk,
|
||||
Claims: &ProvisionerClaims{
|
||||
Claims: &provisioner.Claims{
|
||||
DisableRenewal: &disableRenewal,
|
||||
},
|
||||
},
|
||||
|
@ -113,24 +114,18 @@ func TestAuthorityNew(t *testing.T) {
|
|||
assert.True(t, auth.initOnce)
|
||||
assert.NotNil(t, auth.intermediateIdentity)
|
||||
for _, p := range tc.config.AuthorityConfig.Provisioners {
|
||||
_p, ok := auth.provisionerIDIndex.Load(p.ID())
|
||||
_p, ok := auth.provisioners.Load(p.GetID())
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, p, _p)
|
||||
if len(p.EncryptedKey) > 0 {
|
||||
key, ok := auth.encryptedKeyIndex.Load(p.Key.KeyID)
|
||||
if kid, encryptedKey, ok := p.GetEncryptedKey(); ok {
|
||||
key, ok := auth.provisioners.LoadEncryptedKey(kid)
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, p.EncryptedKey, key)
|
||||
assert.Equals(t, encryptedKey, key)
|
||||
}
|
||||
}
|
||||
// sanity check
|
||||
_, ok = auth.provisionerIDIndex.Load("fooo")
|
||||
_, ok = auth.provisioners.Load("fooo")
|
||||
assert.False(t, ok)
|
||||
|
||||
assert.Equals(t, auth.audiences, []string{
|
||||
"step-certificate-authority",
|
||||
"https://127.0.0.1/sign",
|
||||
"https://127.0.0.1/1.0/sign",
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -2,14 +2,13 @@ package authority
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
type idUsed struct {
|
||||
|
@ -17,49 +16,21 @@ type idUsed struct {
|
|||
Subject string `json:"sub,omitempty"`
|
||||
}
|
||||
|
||||
// Claims extends jwt.Claims with step attributes.
|
||||
// Claims extends jose.Claims with step attributes.
|
||||
type Claims struct {
|
||||
jwt.Claims
|
||||
SANs []string `json:"sans,omitempty"`
|
||||
}
|
||||
|
||||
// matchesAudience returns true if A and B share at least one element.
|
||||
func matchesAudience(as, bs []string) bool {
|
||||
if len(bs) == 0 || len(as) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, b := range bs {
|
||||
for _, a := range as {
|
||||
if b == a || stripPort(a) == stripPort(b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripPort attempts to strip the port from the given url. If parsing the url
|
||||
// produces errors it will just return the passed argument.
|
||||
func stripPort(rawurl string) string {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return rawurl
|
||||
}
|
||||
u.Host = u.Hostname()
|
||||
return u.String()
|
||||
jose.Claims
|
||||
SANs []string `json:"sans,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
// Authorize authorizes a signature request by validating and authenticating
|
||||
// a OTT that must be sent w/ the request.
|
||||
func (a *Authority) Authorize(ott string) ([]interface{}, error) {
|
||||
var (
|
||||
errContext = map[string]interface{}{"ott": ott}
|
||||
claims = Claims{}
|
||||
)
|
||||
func (a *Authority) Authorize(ott string) ([]provisioner.SignOption, error) {
|
||||
var errContext = map[string]interface{}{"ott": ott}
|
||||
|
||||
// Validate payload
|
||||
token, err := jwt.ParseSigned(ott)
|
||||
token, err := jose.ParseSigned(ott)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrapf(err, "authorize: error parsing token"),
|
||||
http.StatusUnauthorized, errContext}
|
||||
|
@ -68,86 +39,52 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) {
|
|||
// Get claims w/out verification. We need to look up the provisioner
|
||||
// key in order to verify the claims and we need the issuer from the claims
|
||||
// before we can look up the provisioner.
|
||||
var claims Claims
|
||||
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, &apiError{err, http.StatusUnauthorized, errContext}
|
||||
}
|
||||
kid := token.Headers[0].KeyID // JWT will only have 1 header.
|
||||
if len(kid) == 0 {
|
||||
return nil, &apiError{errors.New("authorize: token KeyID cannot be empty"),
|
||||
http.StatusUnauthorized, errContext}
|
||||
}
|
||||
pid := claims.Issuer + ":" + kid
|
||||
val, ok := a.provisionerIDIndex.Load(pid)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("authorize: provisioner with id %s not found", pid),
|
||||
http.StatusUnauthorized, errContext}
|
||||
}
|
||||
p, ok := val.(*Provisioner)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("authorize: invalid provisioner type"),
|
||||
http.StatusInternalServerError, errContext}
|
||||
}
|
||||
|
||||
if err = token.Claims(p.Key, &claims); err != nil {
|
||||
return nil, &apiError{err, http.StatusUnauthorized, errContext}
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
// more than a few minutes.
|
||||
if err = claims.ValidateWithLeeway(jwt.Expected{
|
||||
Issuer: p.Name,
|
||||
}, time.Minute); err != nil {
|
||||
return nil, &apiError{errors.Wrapf(err, "authorize: invalid token"),
|
||||
http.StatusUnauthorized, errContext}
|
||||
}
|
||||
|
||||
// Do not accept tokens issued before the start of the ca.
|
||||
// This check is meant as a stopgap solution to the current lack of a persistence layer.
|
||||
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
|
||||
if claims.IssuedAt > 0 && claims.IssuedAt.Time().Before(a.startTime) {
|
||||
return nil, &apiError{errors.New("token issued before the bootstrap of certificate authority"),
|
||||
return nil, &apiError{errors.New("authorize: token issued before the bootstrap of certificate authority"),
|
||||
http.StatusUnauthorized, errContext}
|
||||
}
|
||||
}
|
||||
|
||||
if !matchesAudience(claims.Audience, a.audiences) {
|
||||
return nil, &apiError{errors.New("authorize: token audience invalid"), http.StatusUnauthorized,
|
||||
errContext}
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, &apiError{errors.New("authorize: token subject cannot be empty"),
|
||||
// This method will also validate the audiences for JWK provisioners.
|
||||
p, ok := a.provisioners.LoadByToken(token, &claims.Claims)
|
||||
if !ok {
|
||||
return nil, &apiError{
|
||||
errors.Errorf("authorize: provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")),
|
||||
http.StatusUnauthorized, errContext}
|
||||
}
|
||||
|
||||
// NOTE: This is for backwards compatibility with older versions of cli
|
||||
// and certificates. Older versions added the token subject as the only SAN
|
||||
// in a CSR by default.
|
||||
if len(claims.SANs) == 0 {
|
||||
claims.SANs = []string{claims.Subject}
|
||||
}
|
||||
dnsNames, ips := x509util.SplitSANs(claims.SANs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signOps := []interface{}{
|
||||
&commonNameClaim{claims.Subject},
|
||||
&dnsNamesClaim{dnsNames},
|
||||
&ipAddressesClaim{ips},
|
||||
p,
|
||||
}
|
||||
|
||||
// Store the token to protect against reuse.
|
||||
if _, ok := a.ottMap.LoadOrStore(claims.ID, &idUsed{
|
||||
UsedAt: time.Now().Unix(),
|
||||
Subject: claims.Subject,
|
||||
}); ok {
|
||||
return nil, &apiError{errors.Errorf("token already used"), http.StatusUnauthorized,
|
||||
errContext}
|
||||
var reuseKey string
|
||||
switch p.GetType() {
|
||||
case provisioner.TypeJWK:
|
||||
reuseKey = claims.ID
|
||||
case provisioner.TypeOIDC:
|
||||
reuseKey = claims.Nonce
|
||||
}
|
||||
if reuseKey != "" {
|
||||
if _, ok := a.ottMap.LoadOrStore(reuseKey, &idUsed{
|
||||
UsedAt: time.Now().Unix(),
|
||||
Subject: claims.Subject,
|
||||
}); ok {
|
||||
return nil, &apiError{errors.Errorf("authorize: token already used"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
}
|
||||
|
||||
return signOps, nil
|
||||
// Call the provisioner Authorize method to get the signing options
|
||||
opts, err := p.Authorize(ott)
|
||||
if err != nil {
|
||||
return nil, &apiError{errors.Wrap(err, "authorize"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// authorizeRenewal tries to locate the step provisioner extension, and checks
|
||||
|
@ -157,46 +94,20 @@ func (a *Authority) Authorize(ott string) ([]interface{}, error) {
|
|||
// TODO(mariano): should we authorize by default?
|
||||
func (a *Authority) authorizeRenewal(crt *x509.Certificate) error {
|
||||
errContext := map[string]interface{}{"serialNumber": crt.SerialNumber.String()}
|
||||
for _, e := range crt.Extensions {
|
||||
if e.Id.Equal(stepOIDProvisioner) {
|
||||
var provisioner stepProvisionerASN1
|
||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||
return &apiError{
|
||||
err: errors.Wrap(err, "error decoding step provisioner extension"),
|
||||
code: http.StatusInternalServerError,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
|
||||
// Look for the provisioner, if it cannot be found, renewal will not
|
||||
// be authorized.
|
||||
pid := string(provisioner.Name) + ":" + string(provisioner.CredentialID)
|
||||
val, ok := a.provisionerIDIndex.Load(pid)
|
||||
if !ok {
|
||||
return &apiError{
|
||||
err: errors.Errorf("not found: provisioner %s", pid),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
p, ok := val.(*Provisioner)
|
||||
if !ok {
|
||||
return &apiError{
|
||||
err: errors.Errorf("invalid type: provisioner %s, type %T", pid, val),
|
||||
code: http.StatusInternalServerError,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
if p.Claims.IsDisableRenewal() {
|
||||
return &apiError{
|
||||
err: errors.Errorf("renew disabled: provisioner %s", pid),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||
if !ok {
|
||||
return &apiError{
|
||||
err: errors.New("provisioner not found"),
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
if err := p.AuthorizeRenewal(crt); err != nil {
|
||||
return &apiError{
|
||||
err: err,
|
||||
code: http.StatusUnauthorized,
|
||||
context: errContext,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,100 +7,52 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
stepJOSE "github.com/smallstep/cli/jose"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func TestMatchesAudience(t *testing.T) {
|
||||
type matchesTest struct {
|
||||
a, b []string
|
||||
exp bool
|
||||
func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
tests := map[string]matchesTest{
|
||||
"false arg1 empty": {
|
||||
a: []string{},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: false,
|
||||
},
|
||||
"false arg2 empty": {
|
||||
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{},
|
||||
exp: false,
|
||||
},
|
||||
"false arg1,arg2 empty": {
|
||||
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"step-gateway", "step-cli"},
|
||||
exp: false,
|
||||
},
|
||||
"false": {
|
||||
a: []string{"step-gateway", "step-cli"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: false,
|
||||
},
|
||||
"true": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsA": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsB": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsAB": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"},
|
||||
exp: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripPort(t *testing.T) {
|
||||
type args struct {
|
||||
rawurl string
|
||||
id, err := randutil.ASCII(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
|
||||
claims := struct {
|
||||
jose.Claims
|
||||
SANS []string `json:"sans"`
|
||||
}{
|
||||
{"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"},
|
||||
{"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"},
|
||||
{"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := stripPort(tt.args.rawurl); got != tt.want {
|
||||
t.Errorf("stripPort() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
Claims: jose.Claims{
|
||||
ID: id,
|
||||
Subject: sub,
|
||||
Issuer: iss,
|
||||
IssuedAt: jose.NewNumericDate(iat),
|
||||
NotBefore: jose.NewNumericDate(iat),
|
||||
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||
Audience: []string{aud},
|
||||
},
|
||||
SANS: sans,
|
||||
}
|
||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func TestAuthorize(t *testing.T) {
|
||||
a := testAuthority(t)
|
||||
jwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_priv.jwk",
|
||||
stepJOSE.WithPassword([]byte("pass")))
|
||||
assert.FatalError(t, err)
|
||||
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
|
||||
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||
assert.FatalError(t, err)
|
||||
// Invalid keys
|
||||
keyNoKid := &jose.JSONWebKey{Key: key.Key, KeyID: ""}
|
||||
keyBadKid := &jose.JSONWebKey{Key: key.Key, KeyID: "foo"}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
validIssuer := "step-cli"
|
||||
validAudience := []string{"https://test.ca.smallstep.com/sign"}
|
||||
|
||||
|
@ -120,100 +72,37 @@ func TestAuthorize(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail empty key id": func(t *testing.T) *authorizeTest {
|
||||
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
(&jose.SignerOptions{}).WithType("JWT"))
|
||||
assert.FatalError(t, err)
|
||||
cl := jwt.Claims{
|
||||
Subject: "test.smallstep.com",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
ID: "43",
|
||||
}
|
||||
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
|
||||
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyNoKid)
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorize: token KeyID cannot be empty"),
|
||||
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"fail provisioner not found": func(t *testing.T) *authorizeTest {
|
||||
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo"))
|
||||
assert.FatalError(t, err)
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "test.smallstep.com",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
ID: "43",
|
||||
}
|
||||
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
|
||||
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, keyBadKid)
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorize: provisioner with id step-cli:foo not found"),
|
||||
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"fail invalid provisioner": func(t *testing.T) *authorizeTest {
|
||||
_a := testAuthority(t)
|
||||
|
||||
_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo"))
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_a.provisionerIDIndex.Store(validIssuer+":foo", "42")
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "test.smallstep.com",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
ID: "43",
|
||||
}
|
||||
raw, err := jwt.Signed(_sig).Claims(cl).CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: _a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorize: invalid provisioner type"),
|
||||
http.StatusInternalServerError, context{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"fail invalid issuer": func(t *testing.T) *authorizeTest {
|
||||
cl := jwt.Claims{
|
||||
Subject: "subject",
|
||||
Issuer: "invalid-issuer",
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
}
|
||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
raw, err := generateToken("test.smallstep.com", "invalid-issuer", validAudience[0], nil, now, key)
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("authorize: provisioner with id invalid-issuer:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc not found"),
|
||||
err: &apiError{errors.New("authorize: provisioner not found or invalid audience"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"fail empty subject": func(t *testing.T) *authorizeTest {
|
||||
cl := jwt.Claims{
|
||||
Subject: "",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
}
|
||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
raw, err := generateToken("", validIssuer, validAudience[0], nil, now, key)
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
|
@ -223,64 +112,34 @@ func TestAuthorize(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail verify-sig-failure": func(t *testing.T) *authorizeTest {
|
||||
_, priv2, err := keys.GenerateDefaultKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
invalidKeySig, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.ES256,
|
||||
Key: priv2,
|
||||
}, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
|
||||
assert.FatalError(t, err)
|
||||
cl := jwt.Claims{
|
||||
Subject: "test.smallstep.com",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
}
|
||||
raw, err := jwt.Signed(invalidKeySig).Claims(cl).CompactSerialize()
|
||||
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("square/go-jose: error in cryptographic primitive"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
ott: raw + "00",
|
||||
err: &apiError{errors.New("authorize: error parsing claims: square/go-jose: error in cryptographic primitive"),
|
||||
http.StatusUnauthorized, context{"ott": raw + "00"}},
|
||||
}
|
||||
},
|
||||
"fail token-already-used": func(t *testing.T) *authorizeTest {
|
||||
cl := jwt.Claims{
|
||||
Subject: "test.smallstep.com",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
ID: "42",
|
||||
}
|
||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
|
||||
assert.FatalError(t, err)
|
||||
_, err = a.Authorize(raw)
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
ott: raw,
|
||||
err: &apiError{errors.New("token already used"),
|
||||
err: &apiError{errors.New("authorize: token already used"),
|
||||
http.StatusUnauthorized, context{"ott": raw}},
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *authorizeTest {
|
||||
cl := jwt.Claims{
|
||||
Subject: "test.smallstep.com",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
ID: "43",
|
||||
}
|
||||
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
|
||||
raw, err := generateToken("test.smallstep.com", validIssuer, validAudience[0], nil, now, key)
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
ott: raw,
|
||||
res: []interface{}{"1", "2", "3", "4"},
|
||||
res: []interface{}{"1", "2", "3", "4", "5", "6"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
|
|
@ -1,117 +0,0 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
x509 "github.com/smallstep/cli/pkg/x509"
|
||||
)
|
||||
|
||||
// certClaim interface is implemented by types used to validate specific claims in a
|
||||
// certificate request.
|
||||
type certClaim interface {
|
||||
Valid(crt *x509.Certificate) error
|
||||
}
|
||||
|
||||
// ValidateClaims returns nil if all the claims are validated, it will return
|
||||
// the first error if a claim fails.
|
||||
func validateClaims(crt *x509.Certificate, claims []certClaim) (err error) {
|
||||
for _, c := range claims {
|
||||
if err = c.Valid(crt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// commonNameClaim validates the common name of a certificate request.
|
||||
type commonNameClaim struct {
|
||||
name string
|
||||
}
|
||||
|
||||
// Valid checks that certificate request common name matches the one configured.
|
||||
func (c *commonNameClaim) Valid(crt *x509.Certificate) error {
|
||||
if crt.Subject.CommonName == "" {
|
||||
return errors.New("common name cannot be empty")
|
||||
}
|
||||
if crt.Subject.CommonName != c.name {
|
||||
return errors.Errorf("common name claim failed - got %s, want %s", crt.Subject.CommonName, c.name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dnsNamesClaim struct {
|
||||
names []string
|
||||
}
|
||||
|
||||
// Valid checks that certificate request DNS Names match those configured in
|
||||
// the bootstrap (token) flow.
|
||||
func (c *dnsNamesClaim) Valid(crt *x509.Certificate) error {
|
||||
tokMap := make(map[string]int)
|
||||
for _, e := range c.names {
|
||||
tokMap[e] = 1
|
||||
}
|
||||
crtMap := make(map[string]int)
|
||||
for _, e := range crt.DNSNames {
|
||||
crtMap[e] = 1
|
||||
}
|
||||
if !reflect.DeepEqual(tokMap, crtMap) {
|
||||
return errors.Errorf("DNS names claim failed - got %s, want %s", crt.DNSNames, c.names)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ipAddressesClaim struct {
|
||||
ips []net.IP
|
||||
}
|
||||
|
||||
// Valid checks that certificate request IP Addresses match those configured in
|
||||
// the bootstrap (token) flow.
|
||||
func (c *ipAddressesClaim) Valid(crt *x509.Certificate) error {
|
||||
tokMap := make(map[string]int)
|
||||
for _, e := range c.ips {
|
||||
tokMap[e.String()] = 1
|
||||
}
|
||||
crtMap := make(map[string]int)
|
||||
for _, e := range crt.IPAddresses {
|
||||
crtMap[e.String()] = 1
|
||||
}
|
||||
if !reflect.DeepEqual(tokMap, crtMap) {
|
||||
return errors.Errorf("IP Addresses claim failed - got %v, want %v", crt.IPAddresses, c.ips)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// certTemporalClaim validates the certificate temporal validity settings.
|
||||
type certTemporalClaim struct {
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
}
|
||||
|
||||
// Validate validates the certificate temporal validity settings.
|
||||
func (ctc *certTemporalClaim) Valid(crt *x509.Certificate) error {
|
||||
var (
|
||||
na = crt.NotAfter
|
||||
nb = crt.NotBefore
|
||||
d = na.Sub(nb)
|
||||
now = time.Now()
|
||||
)
|
||||
|
||||
if na.Before(now) {
|
||||
return errors.Errorf("NotAfter: %v cannot be in the past", na)
|
||||
}
|
||||
if na.Before(nb) {
|
||||
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
|
||||
}
|
||||
if d < ctc.min {
|
||||
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
|
||||
d, ctc.min)
|
||||
}
|
||||
if d > ctc.max {
|
||||
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
|
||||
d, ctc.max)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,132 +0,0 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto/x509/pkix"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
x509 "github.com/smallstep/cli/pkg/x509"
|
||||
)
|
||||
|
||||
func TestCommonNameClaim_Valid(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
cnc certClaim
|
||||
crt *x509.Certificate
|
||||
err error
|
||||
}{
|
||||
"empty-common-name": {
|
||||
cnc: &commonNameClaim{name: "foo"},
|
||||
crt: &x509.Certificate{},
|
||||
err: errors.New("common name cannot be empty"),
|
||||
},
|
||||
"wrong-common-name": {
|
||||
cnc: &commonNameClaim{name: "foo"},
|
||||
crt: &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}},
|
||||
err: errors.New("common name claim failed - got bar, want foo"),
|
||||
},
|
||||
"ok": {
|
||||
cnc: &commonNameClaim{name: "foo"},
|
||||
crt: &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.cnc.Valid(tc.crt)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAddressesClaim_Valid(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
iac certClaim
|
||||
crt *x509.Certificate
|
||||
err error
|
||||
}{
|
||||
"unexpected-ip-in-crt": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
|
||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1 1.1.1.1], want [127.0.0.1]"),
|
||||
},
|
||||
"missing-ip-in-crt": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.1.1.1")}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1], want [127.0.0.1 1.1.1.1]"),
|
||||
},
|
||||
"invalid-matcher-nonempty-ips": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
err: errors.New("IP Addresses claim failed - got [127.0.0.1], want []"),
|
||||
},
|
||||
"ok": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
},
|
||||
"ok-multiple-identical-ip-entries": {
|
||||
iac: &ipAddressesClaim{ips: []net.IP{net.ParseIP("127.0.0.1")}},
|
||||
crt: &x509.Certificate{IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1")}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.iac.Valid(tc.crt)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSNamesClaim_Valid(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
dnc certClaim
|
||||
crt *x509.Certificate
|
||||
err error
|
||||
}{
|
||||
"unexpected-dns-name-in-crt": {
|
||||
dnc: &dnsNamesClaim{names: []string{"foo"}},
|
||||
crt: &x509.Certificate{DNSNames: []string{"foo", "bar"}},
|
||||
err: errors.New("DNS names claim failed - got [foo bar], want [foo]"),
|
||||
},
|
||||
"ok": {
|
||||
dnc: &dnsNamesClaim{names: []string{"foo", "bar"}},
|
||||
crt: &x509.Certificate{DNSNames: []string{"bar", "foo"}},
|
||||
},
|
||||
"missing-dns-name-in-crt": {
|
||||
dnc: &dnsNamesClaim{names: []string{"foo", "bar"}},
|
||||
crt: &x509.Certificate{DNSNames: []string{"foo"}},
|
||||
err: errors.New("DNS names claim failed - got [foo], want [foo bar]"),
|
||||
},
|
||||
"ok-multiple-identical-dns-entries": {
|
||||
dnc: &dnsNamesClaim{names: []string{"foo"}},
|
||||
crt: &x509.Certificate{DNSNames: []string{"foo", "foo", "foo"}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.dnc.Valid(tc.crt)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,11 +2,13 @@ package authority
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
)
|
||||
|
@ -25,10 +27,10 @@ var (
|
|||
Renegotiation: false,
|
||||
}
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = ProvisionerClaims{
|
||||
MinTLSDur: &Duration{5 * time.Minute},
|
||||
MaxTLSDur: &Duration{24 * time.Hour},
|
||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
@ -50,16 +52,15 @@ type Config struct {
|
|||
|
||||
// AuthConfig represents the configuration options for the authority.
|
||||
type AuthConfig struct {
|
||||
Provisioners []*Provisioner `json:"provisioners,omitempty"`
|
||||
Template *x509util.ASN1DN `json:"template,omitempty"`
|
||||
Claims *ProvisionerClaims `json:"claims,omitempty"`
|
||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||
Provisioners provisioner.List `json:"provisioners"`
|
||||
Template *x509util.ASN1DN `json:"template,omitempty"`
|
||||
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||
}
|
||||
|
||||
// Validate validates the authority configuration.
|
||||
func (c *AuthConfig) Validate() error {
|
||||
func (c *AuthConfig) Validate(audiences []string) error {
|
||||
var err error
|
||||
|
||||
if c == nil {
|
||||
return errors.New("authority cannot be undefined")
|
||||
}
|
||||
|
@ -70,11 +71,18 @@ func (c *AuthConfig) Validate() error {
|
|||
if c.Claims, err = c.Claims.Init(&globalProvisionerClaims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize provisioners
|
||||
config := provisioner.Config{
|
||||
Claims: *c.Claims,
|
||||
Audiences: audiences,
|
||||
}
|
||||
for _, p := range c.Provisioners {
|
||||
if err := p.Init(c.Claims); err != nil {
|
||||
if err := p.Init(config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if c.Template == nil {
|
||||
c.Template = &x509util.ASN1DN{}
|
||||
}
|
||||
|
@ -153,5 +161,16 @@ func (c *Config) Validate() error {
|
|||
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
|
||||
}
|
||||
|
||||
return c.AuthorityConfig.Validate()
|
||||
return c.AuthorityConfig.Validate(c.getAudiences())
|
||||
}
|
||||
|
||||
// getAudiences returns the legacy and possible urls without the ports that will
|
||||
// be used as the default provisioner audiences. The CA might have proxies in
|
||||
// front so we cannot rely on the port.
|
||||
func (c *Config) getAudiences() []string {
|
||||
audiences := []string{legacyAuthority}
|
||||
for _, name := range c.DNSNames {
|
||||
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
|
||||
}
|
||||
return audiences
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
stepJOSE "github.com/smallstep/cli/jose"
|
||||
|
@ -17,13 +18,13 @@ func TestConfigValidate(t *testing.T) {
|
|||
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
ac := &AuthConfig{
|
||||
Provisioners: []*Provisioner{
|
||||
{
|
||||
Provisioners: provisioner.List{
|
||||
&provisioner.JWK{
|
||||
Name: "Max",
|
||||
Type: "JWK",
|
||||
Key: maxjwk,
|
||||
},
|
||||
{
|
||||
&provisioner.JWK{
|
||||
Name: "step-cli",
|
||||
Type: "JWK",
|
||||
Key: clijwk,
|
||||
|
@ -229,13 +230,13 @@ func TestAuthConfigValidate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
clijwk, err := stepJOSE.ParseKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
p := []*Provisioner{
|
||||
{
|
||||
p := provisioner.List{
|
||||
&provisioner.JWK{
|
||||
Name: "Max",
|
||||
Type: "JWK",
|
||||
Key: maxjwk,
|
||||
},
|
||||
{
|
||||
&provisioner.JWK{
|
||||
Name: "step-cli",
|
||||
Type: "JWK",
|
||||
Key: clijwk,
|
||||
|
@ -263,9 +264,9 @@ func TestAuthConfigValidate(t *testing.T) {
|
|||
"fail-invalid-provisioners": func(t *testing.T) AuthConfigValidateTest {
|
||||
return AuthConfigValidateTest{
|
||||
ac: &AuthConfig{
|
||||
Provisioners: []*Provisioner{
|
||||
{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
||||
{Name: "foo", Key: &jose.JSONWebKey{}},
|
||||
Provisioners: provisioner.List{
|
||||
&provisioner.JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
||||
&provisioner.JWK{Name: "foo", Key: &jose.JSONWebKey{}},
|
||||
},
|
||||
},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
|
@ -293,7 +294,7 @@ func TestAuthConfigValidate(t *testing.T) {
|
|||
for name, get := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := get(t)
|
||||
err := tc.ac.Validate()
|
||||
err := tc.ac.Validate([]string{})
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
package authority
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// ProvisionerClaims so that individual provisioners can override global claims.
|
||||
type ProvisionerClaims struct {
|
||||
globalClaims *ProvisionerClaims
|
||||
// Claims so that individual provisioners can override global claims.
|
||||
type Claims struct {
|
||||
globalClaims *Claims
|
||||
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||
|
@ -19,19 +16,18 @@ type ProvisionerClaims struct {
|
|||
}
|
||||
|
||||
// Init initializes and validates the individual provisioner claims.
|
||||
func (pc *ProvisionerClaims) Init(global *ProvisionerClaims) (*ProvisionerClaims, error) {
|
||||
func (pc *Claims) Init(global *Claims) (*Claims, error) {
|
||||
if pc == nil {
|
||||
pc = &ProvisionerClaims{}
|
||||
pc = &Claims{}
|
||||
}
|
||||
pc.globalClaims = global
|
||||
err := pc.Validate()
|
||||
return pc, err
|
||||
return pc, pc.Validate()
|
||||
}
|
||||
|
||||
// DefaultTLSCertDuration returns the default TLS cert duration for the
|
||||
// provisioner. If the default is not set within the provisioner, then the global
|
||||
// default from the authority configuration will be used.
|
||||
func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration {
|
||||
func (pc *Claims) DefaultTLSCertDuration() time.Duration {
|
||||
if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 {
|
||||
return pc.globalClaims.DefaultTLSCertDuration()
|
||||
}
|
||||
|
@ -41,7 +37,7 @@ func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration {
|
|||
// MinTLSCertDuration returns the minimum TLS cert duration for the provisioner.
|
||||
// If the minimum is not set within the provisioner, then the global
|
||||
// minimum from the authority configuration will be used.
|
||||
func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration {
|
||||
func (pc *Claims) MinTLSCertDuration() time.Duration {
|
||||
if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 {
|
||||
return pc.globalClaims.MinTLSCertDuration()
|
||||
}
|
||||
|
@ -51,7 +47,7 @@ func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration {
|
|||
// MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner.
|
||||
// If the maximum is not set within the provisioner, then the global
|
||||
// maximum from the authority configuration will be used.
|
||||
func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration {
|
||||
func (pc *Claims) MaxTLSCertDuration() time.Duration {
|
||||
if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 {
|
||||
return pc.globalClaims.MaxTLSCertDuration()
|
||||
}
|
||||
|
@ -61,7 +57,7 @@ func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration {
|
|||
// IsDisableRenewal returns if the renewal flow is disabled for the
|
||||
// provisioner. If the property is not set within the provisioner, then the
|
||||
// global value from the authority configuration will be used.
|
||||
func (pc *ProvisionerClaims) IsDisableRenewal() bool {
|
||||
func (pc *Claims) IsDisableRenewal() bool {
|
||||
if pc.DisableRenewal == nil {
|
||||
return pc.globalClaims.IsDisableRenewal()
|
||||
}
|
||||
|
@ -69,7 +65,7 @@ func (pc *ProvisionerClaims) IsDisableRenewal() bool {
|
|||
}
|
||||
|
||||
// Validate validates and modifies the Claims with default values.
|
||||
func (pc *ProvisionerClaims) Validate() error {
|
||||
func (pc *Claims) Validate() error {
|
||||
var (
|
||||
min = pc.MinTLSCertDuration()
|
||||
max = pc.MaxTLSCertDuration()
|
||||
|
@ -93,52 +89,3 @@ func (pc *ProvisionerClaims) Validate() error {
|
|||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Provisioner - authorized entity that can sign tokens necessary for signature requests.
|
||||
type Provisioner struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Key *jose.JSONWebKey `json:"key,omitempty"`
|
||||
EncryptedKey string `json:"encryptedKey,omitempty"`
|
||||
Claims *ProvisionerClaims `json:"claims,omitempty"`
|
||||
}
|
||||
|
||||
// Init initializes and validates a the fields of Provisioner type.
|
||||
func (p *Provisioner) Init(global *ProvisionerClaims) error {
|
||||
switch {
|
||||
case p.Name == "":
|
||||
return errors.New("provisioner name cannot be empty")
|
||||
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
|
||||
case p.Key == nil:
|
||||
return errors.New("provisioner key cannot be empty")
|
||||
}
|
||||
|
||||
var err error
|
||||
p.Claims, err = p.Claims.Init(global)
|
||||
return err
|
||||
}
|
||||
|
||||
// getTLSApps returns a list of modifiers and validators that will be applied to
|
||||
// the certificate.
|
||||
func (p *Provisioner) getTLSApps(so SignOptions) ([]x509util.WithOption, []certClaim, error) {
|
||||
c := p.Claims
|
||||
return []x509util.WithOption{
|
||||
x509util.WithNotBeforeAfterDuration(so.NotBefore,
|
||||
so.NotAfter, c.DefaultTLSCertDuration()),
|
||||
withProvisionerOID(p.Name, p.Key.KeyID),
|
||||
}, []certClaim{
|
||||
&certTemporalClaim{
|
||||
min: c.MinTLSCertDuration(),
|
||||
max: c.MaxTLSCertDuration(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ID returns the provisioner identifier. The name and credential id should
|
||||
// uniquely identify any provisioner.
|
||||
func (p *Provisioner) ID() string {
|
||||
return p.Name + ":" + p.Key.KeyID
|
||||
}
|
212
authority/provisioner/collection.go
Normal file
212
authority/provisioner/collection.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||
const DefaultProvisionersLimit = 20
|
||||
|
||||
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
||||
const DefaultProvisionersMax = 100
|
||||
|
||||
type uidProvisioner struct {
|
||||
provisioner Interface
|
||||
uid string
|
||||
}
|
||||
|
||||
type provisionerSlice []uidProvisioner
|
||||
|
||||
func (p provisionerSlice) Len() int { return len(p) }
|
||||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
// Collection is a memory map of provisioners.
|
||||
type Collection struct {
|
||||
byID *sync.Map
|
||||
byKey *sync.Map
|
||||
sorted provisionerSlice
|
||||
audiences []string
|
||||
}
|
||||
|
||||
// NewCollection initializes a collection of provisioners. The given list of
|
||||
// audiences are the audiences used by the JWT provisioner.
|
||||
func NewCollection(audiences []string) *Collection {
|
||||
return &Collection{
|
||||
byID: new(sync.Map),
|
||||
byKey: new(sync.Map),
|
||||
audiences: audiences,
|
||||
}
|
||||
}
|
||||
|
||||
// Load a provisioner by the ID.
|
||||
func (c *Collection) Load(id string) (Interface, bool) {
|
||||
return loadProvisioner(c.byID, id)
|
||||
}
|
||||
|
||||
// LoadByToken parses the token claims and loads the provisioner associated.
|
||||
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
|
||||
// match with server audiences
|
||||
if matchesAudience(claims.Audience, c.audiences) {
|
||||
// If matches with stored audiences it will be a JWT token (default), and
|
||||
// the id would be <issuer>:<kid>.
|
||||
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
|
||||
}
|
||||
|
||||
// The ID will be just the clientID stored in azp or aud.
|
||||
var payload openIDPayload
|
||||
if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
// audience is required
|
||||
if len(payload.Audience) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
if len(payload.AuthorizedParty) > 0 {
|
||||
return c.Load(payload.AuthorizedParty)
|
||||
}
|
||||
return c.Load(payload.Audience[0])
|
||||
}
|
||||
|
||||
// LoadByCertificate looks for the provisioner extension and extracts the
|
||||
// proper id to load the provisioner.
|
||||
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
|
||||
for _, e := range cert.Extensions {
|
||||
if e.Id.Equal(stepOIDProvisioner) {
|
||||
var provisioner stepProvisionerASN1
|
||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if provisioner.Type == int(TypeJWK) {
|
||||
return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID))
|
||||
}
|
||||
return c.Load(string(provisioner.CredentialID))
|
||||
}
|
||||
}
|
||||
|
||||
// Default to noop provisioner if an extension is not found. This allows to
|
||||
// accept a renewal of a cert without the provisioner extension.
|
||||
return &noop{}, true
|
||||
}
|
||||
|
||||
// LoadEncryptedKey returns an encrypted key by indexed by KeyID. At this moment
|
||||
// only JWK encrypted keys are indexed by KeyID.
|
||||
func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
||||
p, ok := loadProvisioner(c.byKey, keyID)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
_, key, ok := p.GetEncryptedKey()
|
||||
return key, ok
|
||||
}
|
||||
|
||||
// Store adds a provisioner to the collection and enforces the uniqueness of
|
||||
// provisioner IDs.
|
||||
func (c *Collection) Store(p Interface) error {
|
||||
// Store provisioner always in byID. ID must be unique.
|
||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true {
|
||||
return errors.New("cannot add multiple provisioners with the same id")
|
||||
}
|
||||
|
||||
// Store provisioner in byKey if EncryptedKey is defined.
|
||||
if kid, _, ok := p.GetEncryptedKey(); ok {
|
||||
c.byKey.Store(kid, p)
|
||||
}
|
||||
|
||||
// Store sorted provisioners.
|
||||
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||
// Using big endian format to get the strings sorted:
|
||||
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||
bi := make([]byte, 4)
|
||||
sum := provisionerSum(p)
|
||||
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
|
||||
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||
c.sorted = append(c.sorted, uidProvisioner{
|
||||
provisioner: p,
|
||||
uid: hex.EncodeToString(sum),
|
||||
})
|
||||
sort.Sort(c.sorted)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find implements pagination on a list of sorted provisioners.
|
||||
func (c *Collection) Find(cursor string, limit int) (List, string) {
|
||||
switch {
|
||||
case limit <= 0:
|
||||
limit = DefaultProvisionersLimit
|
||||
case limit > DefaultProvisionersMax:
|
||||
limit = DefaultProvisionersMax
|
||||
}
|
||||
|
||||
n := c.sorted.Len()
|
||||
cursor = fmt.Sprintf("%040s", cursor)
|
||||
i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })
|
||||
|
||||
slice := List{}
|
||||
for ; i < n && len(slice) < limit; i++ {
|
||||
slice = append(slice, c.sorted[i].provisioner)
|
||||
}
|
||||
|
||||
if i < n {
|
||||
return slice, strings.TrimLeft(c.sorted[i].uid, "0")
|
||||
}
|
||||
return slice, ""
|
||||
}
|
||||
|
||||
func loadProvisioner(m *sync.Map, key string) (Interface, bool) {
|
||||
i, ok := m.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
p, ok := i.(Interface)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return p, true
|
||||
}
|
||||
|
||||
// provisionerSum returns the SHA1 of the provisioners ID. From this we will
|
||||
// create the unique and sorted id.
|
||||
func provisionerSum(p Interface) []byte {
|
||||
sum := sha1.Sum([]byte(p.GetID()))
|
||||
return sum[:]
|
||||
}
|
||||
|
||||
// matchesAudience returns true if A and B share at least one element.
|
||||
func matchesAudience(as, bs []string) bool {
|
||||
if len(bs) == 0 || len(as) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, b := range bs {
|
||||
for _, a := range as {
|
||||
if b == a || stripPort(a) == stripPort(b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripPort attempts to strip the port from the given url. If parsing the url
|
||||
// produces errors it will just return the passed argument.
|
||||
func stripPort(rawurl string) string {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return rawurl
|
||||
}
|
||||
u.Host = u.Hostname()
|
||||
return u.String()
|
||||
}
|
390
authority/provisioner/collection_test.go
Normal file
390
authority/provisioner/collection_test.go
Normal file
|
@ -0,0 +1,390 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func TestCollection_Load(t *testing.T) {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p.GetID(), p)
|
||||
byID.Store("string", "a-string")
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
}
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok", fields{byID}, args{p.GetID()}, p, true},
|
||||
{"fail", fields{byID}, args{"fail"}, nil, false},
|
||||
{"invalid", fields{byID}, args{"string"}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
}
|
||||
got, got1 := c.Load(tt.args.id)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.Load() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadByToken(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
byID.Store(p2.GetID(), p2)
|
||||
byID.Store(p3.GetID(), p3)
|
||||
byID.Store("string", "a-string")
|
||||
|
||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
token, err := generateSimpleToken(p1.Name, testAudiences[0], jwk)
|
||||
assert.FatalError(t, err)
|
||||
t1, c1, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
jwk, err = decryptJSONWebKey(p2.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
token, err = generateSimpleToken(p2.Name, testAudiences[1], jwk)
|
||||
assert.FatalError(t, err)
|
||||
t2, c2, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t3, c3, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t4, c4, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
audiences []string
|
||||
}
|
||||
type args struct {
|
||||
token *jose.JSONWebToken
|
||||
claims *jose.Claims
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
|
||||
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
|
||||
{"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false},
|
||||
{"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
audiences: tt.fields.audiences,
|
||||
}
|
||||
got, got1 := c.LoadByToken(tt.args.token, tt.args.claims)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadByCertificate(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
byID.Store(p2.GetID(), p2)
|
||||
|
||||
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
||||
assert.FatalError(t, err)
|
||||
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
||||
assert.FatalError(t, err)
|
||||
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
ok1Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok1Ext},
|
||||
}
|
||||
ok2Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok2Ext},
|
||||
}
|
||||
notFoundCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{notFoundExt},
|
||||
}
|
||||
badCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{
|
||||
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
||||
},
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
audiences []string
|
||||
}
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
|
||||
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
||||
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
|
||||
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
audiences: tt.fields.audiences,
|
||||
}
|
||||
got, got1 := c.LoadByCertificate(tt.args.cert)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadEncryptedKey(t *testing.T) {
|
||||
c := NewCollection(testAudiences)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, c.Store(p1))
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, c.Store(p2))
|
||||
|
||||
// Add oidc in byKey.
|
||||
// It should not happen.
|
||||
p2KeyID := p2.keyStore.keySet.Keys[0].KeyID
|
||||
c.byKey.Store(p2KeyID, p2)
|
||||
|
||||
type args struct {
|
||||
keyID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
want1 bool
|
||||
}{
|
||||
{"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true},
|
||||
{"oidc", args{p2KeyID}, "", false},
|
||||
{"notFound", args{"not-found"}, "", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := c.LoadEncryptedKey(tt.args.keyID)
|
||||
if got != tt.want {
|
||||
t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_Store(t *testing.T) {
|
||||
c := NewCollection(testAudiences)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
p Interface
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", args{p1}, false},
|
||||
{"ok2", args{p2}, false},
|
||||
{"fail1", args{p1}, true},
|
||||
{"fail2", args{p2}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := c.Store(tt.args.p); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_Find(t *testing.T) {
|
||||
c, err := generateCollection(10, 10)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
trim := func(s string) string {
|
||||
return strings.TrimLeft(s, "0")
|
||||
}
|
||||
toList := func(ps provisionerSlice) List {
|
||||
l := List{}
|
||||
for _, p := range ps {
|
||||
l = append(l, p.provisioner)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
type args struct {
|
||||
cursor string
|
||||
limit int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want List
|
||||
want1 string
|
||||
}{
|
||||
{"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""},
|
||||
{"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""},
|
||||
{"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)},
|
||||
{"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""},
|
||||
{"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)},
|
||||
{"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)},
|
||||
{"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""},
|
||||
{"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := c.Find(tt.args.cursor, tt.args.limit)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.Find() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_matchesAudience(t *testing.T) {
|
||||
type matchesTest struct {
|
||||
a, b []string
|
||||
exp bool
|
||||
}
|
||||
tests := map[string]matchesTest{
|
||||
"false arg1 empty": {
|
||||
a: []string{},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: false,
|
||||
},
|
||||
"false arg2 empty": {
|
||||
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{},
|
||||
exp: false,
|
||||
},
|
||||
"false arg1,arg2 empty": {
|
||||
a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"step-gateway", "step-cli"},
|
||||
exp: false,
|
||||
},
|
||||
"false": {
|
||||
a: []string{"step-gateway", "step-cli"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: false,
|
||||
},
|
||||
"true": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsA": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsB": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"},
|
||||
exp: true,
|
||||
},
|
||||
"true,portsAB": {
|
||||
a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"},
|
||||
b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"},
|
||||
exp: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_stripPort(t *testing.T) {
|
||||
type args struct {
|
||||
rawurl string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"},
|
||||
{"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"},
|
||||
{"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := stripPort(tt.args.rawurl); got != tt.want {
|
||||
t.Errorf("stripPort() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
45
authority/provisioner/duration.go
Normal file
45
authority/provisioner/duration.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.Duration.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *Duration) UnmarshalJSON(data []byte) (err error) {
|
||||
var (
|
||||
s string
|
||||
_d time.Duration
|
||||
)
|
||||
if d == nil {
|
||||
return errors.New("duration cannot be nil")
|
||||
}
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshaling %s", data)
|
||||
}
|
||||
if _d, err = time.ParseDuration(s); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||
}
|
||||
d.Duration = _d
|
||||
return
|
||||
}
|
61
authority/provisioner/duration_test.go
Normal file
61
authority/provisioner/duration_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDuration_UnmarshalJSON(t *testing.T) {
|
||||
type args struct {
|
||||
data []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
d *Duration
|
||||
args args
|
||||
want *Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty", new(Duration), args{[]byte{}}, new(Duration), true},
|
||||
{"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true},
|
||||
{"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true},
|
||||
{"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true},
|
||||
{"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false},
|
||||
{"nil", nil, args{nil}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(tt.d, tt.want) {
|
||||
t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
d *Duration
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.d.MarshalJSON()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
125
authority/provisioner/jwk.go
Normal file
125
authority/provisioner/jwk.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// jwtPayload extends jwt.Claims with step attributes.
|
||||
type jwtPayload struct {
|
||||
jose.Claims
|
||||
SANs []string `json:"sans,omitempty"`
|
||||
}
|
||||
|
||||
// JWK is the default provisioner, an entity that can sign tokens necessary for
|
||||
// signature requests.
|
||||
type JWK struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
EncryptedKey string `json:"encryptedKey,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
audiences []string
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier. The name and credential id
|
||||
// should uniquely identify any JWK provisioner.
|
||||
func (p *JWK) GetID() string {
|
||||
return p.Name + ":" + p.Key.KeyID
|
||||
}
|
||||
|
||||
// GetName returns the name of the provisioner.
|
||||
func (p *JWK) GetName() string {
|
||||
return p.Name
|
||||
}
|
||||
|
||||
// GetType returns the type of provisioner.
|
||||
func (p *JWK) GetType() Type {
|
||||
return TypeJWK
|
||||
}
|
||||
|
||||
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
|
||||
func (p *JWK) GetEncryptedKey() (string, string, bool) {
|
||||
return p.Key.KeyID, p.EncryptedKey, len(p.EncryptedKey) > 0
|
||||
}
|
||||
|
||||
// Init initializes and validates the fields of a JWK type.
|
||||
func (p *JWK) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
case p.Name == "":
|
||||
return errors.New("provisioner name cannot be empty")
|
||||
case p.Key == nil:
|
||||
return errors.New("provisioner key cannot be empty")
|
||||
}
|
||||
p.Claims, err = p.Claims.Init(&config.Claims)
|
||||
p.audiences = config.Audiences
|
||||
return err
|
||||
}
|
||||
|
||||
// Authorize validates the given token.
|
||||
func (p *JWK) Authorize(token string) ([]SignOption, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
}
|
||||
|
||||
var claims jwtPayload
|
||||
if err = jwt.Claims(p.Key, &claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
// more than a few minutes.
|
||||
if err = claims.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: p.Name,
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid token")
|
||||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, p.audiences) {
|
||||
return nil, errors.New("invalid token: invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errors.New("token subject cannot be empty")
|
||||
}
|
||||
|
||||
// NOTE: This is for backwards compatibility with older versions of cli
|
||||
// and certificates. Older versions added the token subject as the only SAN
|
||||
// in a CSR by default.
|
||||
if len(claims.SANs) == 0 {
|
||||
claims.SANs = []string{claims.Subject}
|
||||
}
|
||||
|
||||
dnsNames, ips := x509util.SplitSANs(claims.SANs)
|
||||
return []SignOption{
|
||||
commonNameValidator(claims.Subject),
|
||||
dnsNamesValidator(dnsNames),
|
||||
ipAddressesValidator(ips),
|
||||
profileDefaultDuration(p.Claims.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
|
||||
newValidityValidator(p.Claims.MinTLSCertDuration(), p.Claims.MaxTLSCertDuration()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeRenewal returns an error if the renewal is disabled.
|
||||
func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||
if p.Claims.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *JWK) AuthorizeRevoke(token string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
256
authority/provisioner/jwk_test.go
Normal file
256
authority/provisioner/jwk_test.go
Normal file
|
@ -0,0 +1,256 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = Claims{
|
||||
MinTLSDur: &Duration{5 * time.Minute},
|
||||
MaxTLSDur: &Duration{24 * time.Hour},
|
||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
||||
func TestJWK_Getters(t *testing.T) {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
if got := p.GetID(); got != p.Name+":"+p.Key.KeyID {
|
||||
t.Errorf("JWK.GetID() = %v, want %v:%v", got, p.Name, p.Key.KeyID)
|
||||
}
|
||||
if got := p.GetName(); got != p.Name {
|
||||
t.Errorf("JWK.GetName() = %v, want %v", got, p.Name)
|
||||
}
|
||||
if got := p.GetType(); got != TypeJWK {
|
||||
t.Errorf("JWK.GetType() = %v, want %v", got, TypeJWK)
|
||||
}
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != p.Key.KeyID || key != p.EncryptedKey || ok == false {
|
||||
t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, p.Key.KeyID, p.EncryptedKey, true)
|
||||
}
|
||||
p.EncryptedKey = ""
|
||||
kid, key, ok = p.GetEncryptedKey()
|
||||
if kid != p.Key.KeyID || key != "" || ok == true {
|
||||
t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, p.Key.KeyID, "", false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_Init(t *testing.T) {
|
||||
type ProvisionerValidateTest struct {
|
||||
p *JWK
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
||||
"fail-empty": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{
|
||||
Type: "JWK",
|
||||
},
|
||||
err: errors.New("provisioner name cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo"},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-key": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo", Type: "bar"},
|
||||
err: errors.New("provisioner key cannot be empty"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
config := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}
|
||||
for name, get := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := get(t)
|
||||
err := tc.p.Init(config)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_Authorize(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
key1, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
key2, err := decryptJSONWebKey(p2.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
|
||||
assert.FatalError(t, err)
|
||||
t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2)
|
||||
assert.FatalError(t, err)
|
||||
t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], "", []string{}, time.Now(), key1)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Invalid tokens
|
||||
parts := strings.Split(t1, ".")
|
||||
key3, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
// missing key
|
||||
failKey, err := generateSimpleToken(p1.Name, testAudiences[0], key3)
|
||||
assert.FatalError(t, err)
|
||||
// invalid token
|
||||
failTok := "foo." + parts[1] + "." + parts[2]
|
||||
// invalid claims
|
||||
failClaims := parts[0] + ".foo." + parts[1]
|
||||
// invalid issuer
|
||||
failIss, err := generateSimpleToken("foobar", testAudiences[0], key1)
|
||||
assert.FatalError(t, err)
|
||||
// invalid audience
|
||||
failAud, err := generateSimpleToken(p1.Name, "foobar", key1)
|
||||
assert.FatalError(t, err)
|
||||
// invalid signature
|
||||
failSig := t1[0 : len(t1)-2]
|
||||
// no subject
|
||||
failSub, err := generateToken("", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now(), key1)
|
||||
assert.FatalError(t, err)
|
||||
// expired
|
||||
failExp, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1)
|
||||
assert.FatalError(t, err)
|
||||
// not before
|
||||
failNbf, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Remove encrypted key for p2
|
||||
p2.EncryptedKey = ""
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, false},
|
||||
{"ok-no-encrypted-key", p2, args{t2}, false},
|
||||
{"ok-no-sans", p1, args{t3}, false},
|
||||
{"fail-key", p1, args{failKey}, true},
|
||||
{"fail-token", p1, args{failTok}, true},
|
||||
{"fail-claims", p1, args{failClaims}, true},
|
||||
{"fail-issuer", p1, args{failIss}, true},
|
||||
{"fail-audience", p1, args{failAud}, true},
|
||||
{"fail-signature", p1, args{failSig}, true},
|
||||
{"fail-subject", p1, args{failSub}, true},
|
||||
{"fail-expired", p1, args{failExp}, true},
|
||||
{"fail-not-before", p1, args{failNbf}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.prov.Authorize(tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NotNil(t, got)
|
||||
assert.Len(t, 6, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{
|
||||
globalClaims: &globalProvisionerClaims,
|
||||
DisableRenewal: &disable,
|
||||
}
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWK_AuthorizeRevoke(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
key1, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
t1, err := generateSimpleToken(p1.Name, testAudiences[0], key1)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *JWK
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"disabled", p1, args{t1}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("JWK.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
135
authority/provisioner/keystore.go
Normal file
135
authority/provisioner/keystore.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCacheAge = 12 * time.Hour
|
||||
defaultCacheJitter = 1 * time.Hour
|
||||
)
|
||||
|
||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
||||
|
||||
type keyStore struct {
|
||||
sync.RWMutex
|
||||
uri string
|
||||
keySet jose.JSONWebKeySet
|
||||
timer *time.Timer
|
||||
expiry time.Time
|
||||
jitter time.Duration
|
||||
}
|
||||
|
||||
func newKeyStore(uri string) (*keyStore, error) {
|
||||
keys, age, err := getKeysFromJWKsURI(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ks := &keyStore{
|
||||
uri: uri,
|
||||
keySet: keys,
|
||||
expiry: getExpirationTime(age),
|
||||
jitter: getCacheJitter(age),
|
||||
}
|
||||
next := ks.nextReloadDuration(age)
|
||||
ks.timer = time.AfterFunc(next, ks.reload)
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
func (ks *keyStore) Close() {
|
||||
ks.timer.Stop()
|
||||
}
|
||||
|
||||
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
||||
ks.RLock()
|
||||
// Force reload if expiration has passed
|
||||
if time.Now().After(ks.expiry) {
|
||||
ks.RUnlock()
|
||||
ks.reload()
|
||||
ks.RLock()
|
||||
}
|
||||
keys = ks.keySet.Key(kid)
|
||||
ks.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (ks *keyStore) reload() {
|
||||
var next time.Duration
|
||||
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
||||
if err != nil {
|
||||
next = ks.nextReloadDuration(ks.jitter / 2)
|
||||
} else {
|
||||
ks.Lock()
|
||||
ks.keySet = keys
|
||||
ks.expiry = getExpirationTime(age)
|
||||
ks.jitter = getCacheJitter(age)
|
||||
next = ks.nextReloadDuration(age)
|
||||
ks.Unlock()
|
||||
}
|
||||
|
||||
ks.Lock()
|
||||
ks.timer.Reset(next)
|
||||
ks.Unlock()
|
||||
}
|
||||
|
||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||
n := rand.Int63n(int64(ks.jitter))
|
||||
age -= time.Duration(n)
|
||||
if age < 0 {
|
||||
age = 0
|
||||
}
|
||||
return age
|
||||
}
|
||||
|
||||
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||
var keys jose.JSONWebKeySet
|
||||
resp, err := http.Get(uri)
|
||||
if err != nil {
|
||||
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
|
||||
return keys, 0, errors.Wrapf(err, "error reading %s", uri)
|
||||
}
|
||||
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
|
||||
}
|
||||
|
||||
func getCacheAge(cacheControl string) time.Duration {
|
||||
age := defaultCacheAge
|
||||
if len(cacheControl) > 0 {
|
||||
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
|
||||
if len(match) > 0 {
|
||||
if len(match[0]) == 2 {
|
||||
maxAge := match[0][1]
|
||||
maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
|
||||
if err != nil {
|
||||
return defaultCacheAge
|
||||
}
|
||||
age = time.Duration(maxAgeInt) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
return age
|
||||
}
|
||||
|
||||
func getCacheJitter(age time.Duration) time.Duration {
|
||||
switch {
|
||||
case age > time.Hour:
|
||||
return defaultCacheJitter
|
||||
default:
|
||||
return age / 3
|
||||
}
|
||||
}
|
||||
|
||||
func getExpirationTime(age time.Duration) time.Time {
|
||||
return time.Now().Truncate(time.Second).Add(age)
|
||||
}
|
121
authority/provisioner/keystore_test.go
Normal file
121
authority/provisioner/keystore_test.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func Test_newKeyStore(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
ks, err := newKeyStore(srv.URL)
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
|
||||
type args struct {
|
||||
uri string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want jose.JSONWebKeySet
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{srv.URL}, ks.keySet, false},
|
||||
{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newKeyStore(tt.args.uri)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
if !reflect.DeepEqual(got.keySet, tt.want) {
|
||||
t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_keyStore(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
ks, err := newKeyStore(srv.URL + "/random")
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
ks.RLock()
|
||||
keySet1 := ks.keySet
|
||||
ks.RUnlock()
|
||||
// Check contents
|
||||
assert.Len(t, 2, keySet1.Keys)
|
||||
assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
|
||||
assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
// Wait for rotation
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
ks.RLock()
|
||||
keySet2 := ks.keySet
|
||||
ks.RUnlock()
|
||||
if reflect.DeepEqual(keySet1, keySet2) {
|
||||
t.Error("keyStore did not rotated")
|
||||
}
|
||||
|
||||
// Check contents
|
||||
assert.Len(t, 2, keySet2.Keys)
|
||||
assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
|
||||
assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
// Check hits
|
||||
resp, err := srv.Client().Get(srv.URL + "/hits")
|
||||
assert.FatalError(t, err)
|
||||
hits := struct {
|
||||
Hits int `json:"hits"`
|
||||
}{}
|
||||
defer resp.Body.Close()
|
||||
err = json.NewDecoder(resp.Body).Decode(&hits)
|
||||
assert.FatalError(t, err)
|
||||
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||
}
|
||||
|
||||
func Test_keyStore_Get(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
ks, err := newKeyStore(srv.URL)
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
|
||||
type args struct {
|
||||
kid string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
ks *keyStore
|
||||
args args
|
||||
wantKeys []jose.JSONWebKey
|
||||
}{
|
||||
{"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}},
|
||||
{"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}},
|
||||
{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
|
||||
t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
37
authority/provisioner/noop.go
Normal file
37
authority/provisioner/noop.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package provisioner
|
||||
|
||||
import "crypto/x509"
|
||||
|
||||
// noop provisioners is a provisioner that accepts anything.
|
||||
type noop struct{}
|
||||
|
||||
func (p *noop) GetID() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (p *noop) GetName() string {
|
||||
return "noop"
|
||||
}
|
||||
func (p *noop) GetType() Type {
|
||||
return noopType
|
||||
}
|
||||
|
||||
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func (p *noop) Init(config Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *noop) Authorize(token string) ([]SignOption, error) {
|
||||
return []SignOption{}, nil
|
||||
}
|
||||
|
||||
func (p *noop) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *noop) AuthorizeRevoke(token string) error {
|
||||
return nil
|
||||
}
|
27
authority/provisioner/noop_test.go
Normal file
27
authority/provisioner/noop_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func Test_noop(t *testing.T) {
|
||||
p := noop{}
|
||||
assert.Equals(t, "noop", p.GetID())
|
||||
assert.Equals(t, "noop", p.GetName())
|
||||
assert.Equals(t, noopType, p.GetType())
|
||||
assert.Equals(t, nil, p.Init(Config{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRenewal(&x509.Certificate{}))
|
||||
assert.Equals(t, nil, p.AuthorizeRevoke("foo"))
|
||||
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
assert.Equals(t, "", kid)
|
||||
assert.Equals(t, "", key)
|
||||
assert.Equals(t, false, ok)
|
||||
|
||||
sigOptions, err := p.Authorize("foo")
|
||||
assert.Equals(t, []SignOption{}, sigOptions)
|
||||
assert.Equals(t, nil, err)
|
||||
}
|
243
authority/provisioner/oidc.go
Normal file
243
authority/provisioner/oidc.go
Normal file
|
@ -0,0 +1,243 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// openIDConfiguration contains the necessary properties in the
|
||||
// `/.well-known/openid-configuration` document.
|
||||
type openIDConfiguration struct {
|
||||
Issuer string `json:"issuer"`
|
||||
JWKSetURI string `json:"jwks_uri"`
|
||||
}
|
||||
|
||||
// Validate validates the values in a well-known OpenID configuration endpoint.
|
||||
func (c openIDConfiguration) Validate() error {
|
||||
switch {
|
||||
case c.Issuer == "":
|
||||
return errors.New("issuer cannot be empty")
|
||||
case c.JWKSetURI == "":
|
||||
return errors.New("jwks_uri cannot be empty")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// openIDPayload represents the fields on the id_token JWT payload.
|
||||
type openIDPayload struct {
|
||||
jose.Claims
|
||||
AtHash string `json:"at_hash"`
|
||||
AuthorizedParty string `json:"azp"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Hd string `json:"hd"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// OIDC represents an OAuth 2.0 OpenID Connect provider.
|
||||
//
|
||||
// ClientSecret is mandatory, but it can be an empty string.
|
||||
type OIDC struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
ConfigurationEndpoint string `json:"configurationEndpoint"`
|
||||
Admins []string `json:"admins,omitempty"`
|
||||
Domains []string `json:"domains,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
configuration openIDConfiguration
|
||||
keyStore *keyStore
|
||||
}
|
||||
|
||||
// IsAdmin returns true if the given email is in the Admins whitelist, false
|
||||
// otherwise.
|
||||
func (o *OIDC) IsAdmin(email string) bool {
|
||||
email = sanitizeEmail(email)
|
||||
for _, e := range o.Admins {
|
||||
if email == sanitizeEmail(e) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sanitizeEmail(email string) string {
|
||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||
email = email[:i] + strings.ToLower(email[i:])
|
||||
}
|
||||
return email
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier, the OIDC provisioner the
|
||||
// uses the clientID for this.
|
||||
func (o *OIDC) GetID() string {
|
||||
return o.ClientID
|
||||
}
|
||||
|
||||
// GetName returns the name of the provisioner.
|
||||
func (o *OIDC) GetName() string {
|
||||
return o.Name
|
||||
}
|
||||
|
||||
// GetType returns the type of provisioner.
|
||||
func (o *OIDC) GetType() Type {
|
||||
return TypeOIDC
|
||||
}
|
||||
|
||||
// GetEncryptedKey is not available in an OIDC provisioner.
|
||||
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Init validates and initializes the OIDC provider.
|
||||
func (o *OIDC) Init(config Config) (err error) {
|
||||
switch {
|
||||
case o.Type == "":
|
||||
return errors.New("type cannot be empty")
|
||||
case o.Name == "":
|
||||
return errors.New("name cannot be empty")
|
||||
case o.ClientID == "":
|
||||
return errors.New("clientID cannot be empty")
|
||||
case o.ConfigurationEndpoint == "":
|
||||
return errors.New("configurationEndpoint cannot be empty")
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if o.Claims, err = o.Claims.Init(&config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
// Decode and validate openid-configuration endpoint
|
||||
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := o.configuration.Validate(); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
|
||||
}
|
||||
// Get JWK key set
|
||||
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePayload validates the given token payload.
|
||||
func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no more
|
||||
// than a few minutes.
|
||||
if err := p.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: o.configuration.Issuer,
|
||||
Audience: jose.Audience{o.ClientID},
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return errors.Wrap(err, "failed to validate payload")
|
||||
}
|
||||
|
||||
// Validate azp if present
|
||||
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
|
||||
return errors.New("failed to validate payload: invalid azp")
|
||||
}
|
||||
|
||||
// Enforce an email claim
|
||||
if p.Email == "" {
|
||||
return errors.New("failed to validate payload: email not found")
|
||||
}
|
||||
|
||||
// Validate domains (case-insensitive)
|
||||
if !o.IsAdmin(p.Email) && len(o.Domains) > 0 {
|
||||
email := sanitizeEmail(p.Email)
|
||||
var found bool
|
||||
for _, d := range o.Domains {
|
||||
if strings.HasSuffix(email, "@"+strings.ToLower(d)) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("failed to validate payload: email is not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authorize validates the given token.
|
||||
func (o *OIDC) Authorize(token string) ([]SignOption, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
}
|
||||
|
||||
// Parse claims to get the kid
|
||||
var claims openIDPayload
|
||||
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
}
|
||||
|
||||
found := false
|
||||
kid := jwt.Headers[0].KeyID
|
||||
keys := o.keyStore.Get(kid)
|
||||
for _, key := range keys {
|
||||
if err := jwt.Claims(key, &claims); err == nil {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("cannot validate token")
|
||||
}
|
||||
|
||||
if err := o.ValidatePayload(claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Admins should be able to authorize any SAN
|
||||
if o.IsAdmin(claims.Email) {
|
||||
return []SignOption{
|
||||
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return []SignOption{
|
||||
emailOnlyIdentity(claims.Email),
|
||||
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeRenewal returns an error if the renewal is disabled.
|
||||
func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||
if o.Claims.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", o.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (o *OIDC) AuthorizeRevoke(token string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func getAndDecode(uri string, v interface{}) error {
|
||||
resp, err := http.Get(uri)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to connect to %s", uri)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
|
||||
return errors.Wrapf(err, "error reading %s", uri)
|
||||
}
|
||||
return nil
|
||||
}
|
327
authority/provisioner/oidc_test.go
Normal file
327
authority/provisioner/oidc_test.go
Normal file
|
@ -0,0 +1,327 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func Test_openIDConfiguration_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Issuer string
|
||||
JWKSetURI string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"the-issuer", "the-jwks-uri"}, false},
|
||||
{"no-issuer", fields{"", "the-jwks-uri"}, true},
|
||||
{"no-jwks-uri", fields{"the-issuer", ""}, true},
|
||||
{"empty", fields{"", ""}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := openIDConfiguration{
|
||||
Issuer: tt.fields.Issuer,
|
||||
JWKSetURI: tt.fields.JWKSetURI,
|
||||
}
|
||||
if err := c.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("openIDConfiguration.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_Getters(t *testing.T) {
|
||||
p, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
if got := p.GetID(); got != p.ClientID {
|
||||
t.Errorf("OIDC.GetID() = %v, want %v", got, p.ClientID)
|
||||
}
|
||||
if got := p.GetName(); got != p.Name {
|
||||
t.Errorf("OIDC.GetName() = %v, want %v", got, p.Name)
|
||||
}
|
||||
if got := p.GetType(); got != TypeOIDC {
|
||||
t.Errorf("OIDC.GetType() = %v, want %v", got, TypeOIDC)
|
||||
}
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != "" || key != "" || ok == true {
|
||||
t.Errorf("OIDC.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, "", "", false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_Init(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
config := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Type string
|
||||
Name string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ConfigurationEndpoint string
|
||||
Claims *Claims
|
||||
Admins []string
|
||||
Domains []string
|
||||
}
|
||||
type args struct {
|
||||
config Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
|
||||
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}, nil}, args{config}, false},
|
||||
{"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, []string{"smallstep.com"}}, args{config}, false},
|
||||
{"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
|
||||
{"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true},
|
||||
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &OIDC{
|
||||
Type: tt.fields.Type,
|
||||
Name: tt.fields.Name,
|
||||
ClientID: tt.fields.ClientID,
|
||||
ConfigurationEndpoint: tt.fields.ConfigurationEndpoint,
|
||||
Claims: tt.fields.Claims,
|
||||
Admins: tt.fields.Admins,
|
||||
}
|
||||
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr == false {
|
||||
assert.Len(t, 2, p.keyStore.keySet.Keys)
|
||||
assert.Equals(t, openIDConfiguration{
|
||||
Issuer: "the-issuer",
|
||||
JWKSetURI: srv.URL + "/jwks_uri",
|
||||
}, p.configuration)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_Authorize(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
var keys jose.JSONWebKeySet
|
||||
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||
|
||||
// Create test provisioners
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
// Admin + Domains
|
||||
p3.Admins = []string{"name@smallstep.com", "root@example.com"}
|
||||
p3.Domains = []string{"smallstep.com"}
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
config := Config{Claims: globalProvisionerClaims}
|
||||
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
assert.FatalError(t, p1.Init(config))
|
||||
assert.FatalError(t, p2.Init(config))
|
||||
assert.FatalError(t, p3.Init(config))
|
||||
|
||||
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t2, err := generateSimpleToken("the-issuer", p2.ClientID, &keys.Keys[1])
|
||||
assert.FatalError(t, err)
|
||||
t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Admin email not in domains
|
||||
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// Invalid email
|
||||
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
failDomain, err := generateToken("subject", "the-issuer", p3.ClientID, "name@example.com", []string{}, time.Now(), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Invalid tokens
|
||||
parts := strings.Split(t1, ".")
|
||||
key, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
// missing key
|
||||
failKey, err := generateSimpleToken("the-issuer", p1.ClientID, key)
|
||||
assert.FatalError(t, err)
|
||||
// invalid token
|
||||
failTok := "foo." + parts[1] + "." + parts[2]
|
||||
// invalid claims
|
||||
failClaims := parts[0] + ".foo." + parts[1]
|
||||
// invalid issuer
|
||||
failIss, err := generateSimpleToken("bad-issuer", p1.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// invalid audience
|
||||
failAud, err := generateSimpleToken("the-issuer", "foobar", &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// invalid signature
|
||||
failSig := t1[0 : len(t1)-2]
|
||||
// expired
|
||||
failExp, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
// not before
|
||||
failNbf, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(360*time.Second), &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", p1, args{t1}, false},
|
||||
{"ok2", p2, args{t2}, false},
|
||||
{"admin", p3, args{t3}, false},
|
||||
{"admin", p3, args{okAdmin}, false},
|
||||
{"fail-email", p3, args{failEmail}, true},
|
||||
{"fail-domain", p3, args{failDomain}, true},
|
||||
{"fail-key", p1, args{failKey}, true},
|
||||
{"fail-token", p1, args{failTok}, true},
|
||||
{"fail-claims", p1, args{failClaims}, true},
|
||||
{"fail-issuer", p1, args{failIss}, true},
|
||||
{"fail-audience", p1, args{failAud}, true},
|
||||
{"fail-signature", p1, args{failSig}, true},
|
||||
{"fail-expired", p1, args{failExp}, true},
|
||||
{"fail-not-before", p1, args{failNbf}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.prov.Authorize(tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Println(tt)
|
||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NotNil(t, got)
|
||||
if tt.name == "admin" {
|
||||
assert.Len(t, 3, got)
|
||||
} else {
|
||||
assert.Len(t, 4, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{
|
||||
globalClaims: &globalProvisionerClaims,
|
||||
DisableRenewal: &disable,
|
||||
}
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, false},
|
||||
{"fail", p2, args{nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
var keys jose.JSONWebKeySet
|
||||
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
|
||||
|
||||
// Create test provisioners
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
config := Config{Claims: globalProvisionerClaims}
|
||||
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
|
||||
assert.FatalError(t, p1.Init(config))
|
||||
|
||||
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *OIDC
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"disabled", p1, args{t1}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sanitizeEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
want string
|
||||
}{
|
||||
{"equal", "name@smallstep.com", "name@smallstep.com"},
|
||||
{"domain-insensitive", "name@SMALLSTEP.COM", "name@smallstep.com"},
|
||||
{"local-sensitive", "NaMe@smallSTEP.CoM", "NaMe@smallstep.com"},
|
||||
{"multiple-@", "NaMe@NaMe@smallSTEP.CoM", "NaMe@NaMe@smallstep.com"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := sanitizeEmail(tt.email); got != tt.want {
|
||||
t.Errorf("sanitizeEmail() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
82
authority/provisioner/provisioner.go
Normal file
82
authority/provisioner/provisioner.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Interface is the interface that all provisioner types must implement.
|
||||
type Interface interface {
|
||||
GetID() string
|
||||
GetName() string
|
||||
GetType() Type
|
||||
GetEncryptedKey() (kid string, key string, ok bool)
|
||||
Init(config Config) error
|
||||
Authorize(token string) ([]SignOption, error)
|
||||
AuthorizeRenewal(cert *x509.Certificate) error
|
||||
AuthorizeRevoke(token string) error
|
||||
}
|
||||
|
||||
// Type indicates the provisioner Type.
|
||||
type Type int
|
||||
|
||||
const (
|
||||
noopType Type = 0
|
||||
|
||||
// TypeJWK is used to indicate the JWK provisioners.
|
||||
TypeJWK Type = 1
|
||||
|
||||
// TypeOIDC is used to indicate the OIDC provisioners.
|
||||
TypeOIDC Type = 2
|
||||
)
|
||||
|
||||
// Config defines the default parameters used in the initialization of
|
||||
// provisioners.
|
||||
type Config struct {
|
||||
// Claims are the default claims.
|
||||
Claims Claims
|
||||
// Audiences are the audiences used in the default provisioner, (JWK).
|
||||
Audiences []string
|
||||
}
|
||||
|
||||
type provisioner struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// List represents a list of provisioners.
|
||||
type List []Interface
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler and allows to unmarshal a list of a
|
||||
// interfaces into the right type.
|
||||
func (l *List) UnmarshalJSON(data []byte) error {
|
||||
ps := []json.RawMessage{}
|
||||
if err := json.Unmarshal(data, &ps); err != nil {
|
||||
return errors.Wrap(err, "error unmarshaling provisioner list")
|
||||
}
|
||||
|
||||
*l = List{}
|
||||
for _, data := range ps {
|
||||
var typ provisioner
|
||||
if err := json.Unmarshal(data, &typ); err != nil {
|
||||
return errors.Errorf("error unmarshaling provisioner")
|
||||
}
|
||||
var p Interface
|
||||
switch strings.ToLower(typ.Type) {
|
||||
case "jwk":
|
||||
p = &JWK{}
|
||||
case "oidc":
|
||||
p = &OIDC{}
|
||||
default:
|
||||
return errors.Errorf("provisioner type %s not supported", typ.Type)
|
||||
}
|
||||
if err := json.Unmarshal(data, p); err != nil {
|
||||
return errors.Errorf("error unmarshaling provisioner")
|
||||
}
|
||||
*l = append(*l, p)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
233
authority/provisioner/sign_options.go
Normal file
233
authority/provisioner/sign_options.go
Normal file
|
@ -0,0 +1,233 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
)
|
||||
|
||||
// Options contains the options that can be passed to the Sign method.
|
||||
type Options struct {
|
||||
NotAfter time.Time `json:"notAfter"`
|
||||
NotBefore time.Time `json:"notBefore"`
|
||||
}
|
||||
|
||||
// SignOption is the interface used to collect all extra options used in the
|
||||
// Sign method.
|
||||
type SignOption interface{}
|
||||
|
||||
// CertificateValidator is the interface used to validate a X.509 certificate.
|
||||
type CertificateValidator interface {
|
||||
SignOption
|
||||
Valid(crt *x509.Certificate) error
|
||||
}
|
||||
|
||||
// CertificateRequestValidator is the interface used to validate a X.509
|
||||
// certificate request.
|
||||
type CertificateRequestValidator interface {
|
||||
SignOption
|
||||
Valid(req *x509.CertificateRequest) error
|
||||
}
|
||||
|
||||
// ProfileModifier is the interface used to add custom options to the profile
|
||||
// constructor. The options are used to modify the final certificate.
|
||||
type ProfileModifier interface {
|
||||
SignOption
|
||||
Option(o Options) x509util.WithOption
|
||||
}
|
||||
|
||||
// profileWithOption is a wrapper against x509util.WithOption to conform the
|
||||
// interface.
|
||||
type profileWithOption x509util.WithOption
|
||||
|
||||
func (v profileWithOption) Option(Options) x509util.WithOption {
|
||||
return x509util.WithOption(v)
|
||||
}
|
||||
|
||||
// profileDefaultDuration is a wrapper against x509util.WithOption to conform the
|
||||
// interface.
|
||||
type profileDefaultDuration time.Duration
|
||||
|
||||
func (v profileDefaultDuration) Option(so Options) x509util.WithOption {
|
||||
return x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, time.Duration(v))
|
||||
}
|
||||
|
||||
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
||||
// SAN provided is the given email address.
|
||||
type emailOnlyIdentity string
|
||||
|
||||
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
||||
switch {
|
||||
case len(req.DNSNames) > 0:
|
||||
return errors.New("certificate request cannot contain DNS names")
|
||||
case len(req.IPAddresses) > 0:
|
||||
return errors.New("certificate request cannot contain IP addresses")
|
||||
case len(req.URIs) > 0:
|
||||
return errors.New("certificate request cannot contain URIs")
|
||||
case len(req.EmailAddresses) == 0:
|
||||
return errors.New("certificate request does not contain any email address")
|
||||
case len(req.EmailAddresses) > 1:
|
||||
return errors.New("certificate request does not contain too many email addresses")
|
||||
case req.EmailAddresses[0] == "":
|
||||
return errors.New("certificate request cannot contain an empty email address")
|
||||
case req.EmailAddresses[0] != string(e):
|
||||
return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// commonNameValidator validates the common name of a certificate request.
|
||||
type commonNameValidator string
|
||||
|
||||
// Valid checks that certificate request common name matches the one configured.
|
||||
func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
|
||||
if req.Subject.CommonName == "" {
|
||||
return errors.New("certificate request cannot contain an empty common name")
|
||||
}
|
||||
if req.Subject.CommonName != string(v) {
|
||||
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dnsNamesValidator validates the DNS names SAN of a certificate request.
|
||||
type dnsNamesValidator []string
|
||||
|
||||
// Valid checks that certificate request DNS Names match those configured in
|
||||
// the bootstrap (token) flow.
|
||||
func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error {
|
||||
want := make(map[string]bool)
|
||||
for _, s := range v {
|
||||
want[s] = true
|
||||
}
|
||||
got := make(map[string]bool)
|
||||
for _, s := range req.DNSNames {
|
||||
got[s] = true
|
||||
}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ipAddressesValidator validates the IP addresses SAN of a certificate request.
|
||||
type ipAddressesValidator []net.IP
|
||||
|
||||
// Valid checks that certificate request IP Addresses match those configured in
|
||||
// the bootstrap (token) flow.
|
||||
func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error {
|
||||
want := make(map[string]bool)
|
||||
for _, ip := range v {
|
||||
want[ip.String()] = true
|
||||
}
|
||||
got := make(map[string]bool)
|
||||
for _, ip := range req.IPAddresses {
|
||||
got[ip.String()] = true
|
||||
}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validityValidator validates the certificate temporal validity settings.
|
||||
type validityValidator struct {
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
}
|
||||
|
||||
// newValidityValidator return a new validity validator.
|
||||
func newValidityValidator(min, max time.Duration) *validityValidator {
|
||||
return &validityValidator{min: min, max: max}
|
||||
}
|
||||
|
||||
// Validate validates the certificate temporal validity settings.
|
||||
func (v *validityValidator) Valid(crt *x509.Certificate) error {
|
||||
var (
|
||||
na = crt.NotAfter
|
||||
nb = crt.NotBefore
|
||||
d = na.Sub(nb)
|
||||
now = time.Now()
|
||||
)
|
||||
|
||||
if na.Before(now) {
|
||||
return errors.Errorf("NotAfter: %v cannot be in the past", na)
|
||||
}
|
||||
if na.Before(nb) {
|
||||
return errors.Errorf("NotAfter: %v cannot be before NotBefore: %v", na, nb)
|
||||
}
|
||||
if d < v.min {
|
||||
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v",
|
||||
d, v.min)
|
||||
}
|
||||
if d > v.max {
|
||||
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
|
||||
d, v.max)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
||||
)
|
||||
|
||||
type stepProvisionerASN1 struct {
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
}
|
||||
|
||||
type provisionerExtensionOption struct {
|
||||
Type int
|
||||
Name string
|
||||
CredentialID string
|
||||
}
|
||||
|
||||
func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisionerExtensionOption {
|
||||
return &provisionerExtensionOption{
|
||||
Type: int(typ),
|
||||
Name: name,
|
||||
CredentialID: credentialID,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
||||
return func(p x509util.Profile) error {
|
||||
crt := p.Subject()
|
||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
crt.ExtraExtensions = append(crt.ExtraExtensions, ext)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) {
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||
Type: typ,
|
||||
Name: []byte(name),
|
||||
CredentialID: []byte(credentialID),
|
||||
})
|
||||
if err != nil {
|
||||
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
|
||||
}
|
||||
return pkix.Extension{
|
||||
Id: stepOIDProvisioner,
|
||||
Critical: false,
|
||||
Value: b,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Avoid deadcode warning in profileWithOption
|
||||
_ = profileWithOption(nil)
|
||||
}
|
152
authority/provisioner/sign_options_test.go
Normal file
152
authority/provisioner/sign_options_test.go
Normal file
|
@ -0,0 +1,152 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_emailOnlyIdentity_Valid(t *testing.T) {
|
||||
uri, err := url.Parse("https://example.com/1.0/getUser")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
e emailOnlyIdentity
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com"}}}, false},
|
||||
{"DNSNames", "name@smallstep.com", args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, true},
|
||||
{"IPAddresses", "name@smallstep.com", args{&x509.CertificateRequest{IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}}}, true},
|
||||
{"URIs", "name@smallstep.com", args{&x509.CertificateRequest{URIs: []*url.URL{uri}}}, true},
|
||||
{"no-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{}}}, true},
|
||||
{"empty-email", "", args{&x509.CertificateRequest{EmailAddresses: []string{""}}}, true},
|
||||
{"multiple-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com", "foo@smallstep.com"}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.e.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("emailOnlyIdentity.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_commonNameValidator_Valid(t *testing.T) {
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
v commonNameValidator
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false},
|
||||
{"empty", "", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, true},
|
||||
{"wrong", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("commonNameValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dnsNamesValidator_Valid(t *testing.T) {
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
v dnsNamesValidator
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok0", []string{}, args{&x509.CertificateRequest{DNSNames: []string{}}}, false},
|
||||
{"ok1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, false},
|
||||
{"ok2", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "bar.zar"}}}, false},
|
||||
{"ok3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, false},
|
||||
{"fail1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar"}}}, true},
|
||||
{"fail2", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, true},
|
||||
{"fail3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "zar.bar"}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("dnsNamesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipAddressesValidator_Valid(t *testing.T) {
|
||||
ip1 := net.IPv4(10, 3, 2, 1)
|
||||
ip2 := net.IPv4(10, 3, 2, 2)
|
||||
ip3 := net.IPv4(10, 3, 2, 3)
|
||||
|
||||
type args struct {
|
||||
req *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
v ipAddressesValidator
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok0", []net.IP{}, args{&x509.CertificateRequest{IPAddresses: []net.IP{}}}, false},
|
||||
{"ok1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1}}}, false},
|
||||
{"ok2", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip2}}}, false},
|
||||
{"ok3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, false},
|
||||
{"fail1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2}}}, true},
|
||||
{"fail2", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, true},
|
||||
{"fail3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip3}}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ipAddressesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validityValidator_Valid(t *testing.T) {
|
||||
type fields struct {
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
}
|
||||
type args struct {
|
||||
crt *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := &validityValidator{
|
||||
min: tt.fields.min,
|
||||
max: tt.fields.max,
|
||||
}
|
||||
if err := v.Valid(tt.args.crt); (err != nil) != tt.wantErr {
|
||||
t.Errorf("validityValidator.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
11
authority/provisioner/testdata/root_ca.crt
vendored
Normal file
11
authority/provisioner/testdata/root_ca.crt
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf
|
||||
MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla
|
||||
Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg
|
||||
Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN
|
||||
Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw
|
||||
QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU
|
||||
B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c
|
||||
ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET
|
||||
/A8LXNH4M06A7vE=
|
||||
-----END CERTIFICATE-----
|
272
authority/provisioner/utils_test.go
Normal file
272
authority/provisioner/utils_test.go
Normal file
|
@ -0,0 +1,272 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
var testAudiences = []string{
|
||||
"https://ca.smallstep.com/sign",
|
||||
"https://ca.smallsteomcom/1.0/sign",
|
||||
}
|
||||
|
||||
func must(args ...interface{}) []interface{} {
|
||||
if l := len(args); l > 0 && args[l-1] != nil {
|
||||
if err, ok := args[l-1].(error); ok {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fp, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk.KeyID = string(hex.EncodeToString(fp))
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) {
|
||||
var keySet jose.JSONWebKeySet
|
||||
for i := 0; i < n; i++ {
|
||||
key, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return jose.JSONWebKeySet{}, err
|
||||
}
|
||||
keySet.Keys = append(keySet.Keys, *key)
|
||||
}
|
||||
return keySet, nil
|
||||
}
|
||||
|
||||
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
|
||||
b, err := json.Marshal(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts := new(jose.EncrypterOptions)
|
||||
opts.WithContentType(jose.ContentType("jwk+json"))
|
||||
recipient := jose.Recipient{
|
||||
Algorithm: jose.PBES2_HS256_A128KW,
|
||||
Key: []byte("password"),
|
||||
PBES2Count: jose.PBKDF2Iterations,
|
||||
PBES2Salt: salt,
|
||||
}
|
||||
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypter.Encrypt(b)
|
||||
}
|
||||
|
||||
func decryptJSONWebKey(key string) (*jose.JSONWebKey, error) {
|
||||
enc, err := jose.ParseEncrypted(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, err := enc.Decrypt([]byte("password"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk := new(jose.JSONWebKey)
|
||||
if err := json.Unmarshal(b, jwk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
func generateJWK() (*JWK, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwe, err := encryptJSONWebKey(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
public := jwk.Public()
|
||||
encrypted, err := jwe.CompactSerialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &JWK{
|
||||
Name: name,
|
||||
Type: "JWK",
|
||||
Key: &public,
|
||||
EncryptedKey: encrypted,
|
||||
Claims: &globalProvisionerClaims,
|
||||
audiences: testAudiences,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateOIDC() (*OIDC, error) {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientID, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
issuer, err := randutil.Alphanumeric(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OIDC{
|
||||
Name: name,
|
||||
Type: "OIDC",
|
||||
ClientID: clientID,
|
||||
ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration",
|
||||
Claims: &globalProvisionerClaims,
|
||||
configuration: openIDConfiguration{
|
||||
Issuer: issuer,
|
||||
JWKSetURI: "https://example.com/.well-known/jwks",
|
||||
},
|
||||
keyStore: &keyStore{
|
||||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
expiry: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateCollection(nJWK, nOIDC int) (*Collection, error) {
|
||||
col := NewCollection(testAudiences)
|
||||
for i := 0; i < nJWK; i++ {
|
||||
p, err := generateJWK()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
col.Store(p)
|
||||
}
|
||||
for i := 0; i < nOIDC; i++ {
|
||||
p, err := generateOIDC()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
col.Store(p)
|
||||
}
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
|
||||
return generateToken("subject", iss, aud, "name@smallstep.com", []string{"test.smallstep.com"}, time.Now(), jwk)
|
||||
}
|
||||
|
||||
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
id, err := randutil.ASCII(64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims := struct {
|
||||
jose.Claims
|
||||
Email string `json:"email"`
|
||||
SANS []string `json:"sans"`
|
||||
}{
|
||||
Claims: jose.Claims{
|
||||
ID: id,
|
||||
Subject: sub,
|
||||
Issuer: iss,
|
||||
IssuedAt: jose.NewNumericDate(iat),
|
||||
NotBefore: jose.NewNumericDate(iat),
|
||||
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
|
||||
Audience: []string{aud},
|
||||
},
|
||||
Email: email,
|
||||
SANS: sans,
|
||||
}
|
||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
||||
tok, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
claims := new(jose.Claims)
|
||||
if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return tok, claims, nil
|
||||
}
|
||||
|
||||
func generateJWKServer(n int) *httptest.Server {
|
||||
hits := struct {
|
||||
Hits int `json:"hits"`
|
||||
}{}
|
||||
writeJSON := func(w http.ResponseWriter, v interface{}) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(b)
|
||||
}
|
||||
getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
|
||||
var ret jose.JSONWebKeySet
|
||||
for _, k := range ks.Keys {
|
||||
ret.Keys = append(ret.Keys, k.Public())
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
srv := httptest.NewUnstartedServer(nil)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits.Hits++
|
||||
switch r.RequestURI {
|
||||
case "/error":
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
case "/hits":
|
||||
writeJSON(w, hits)
|
||||
case "/openid-configuration", "/.well-known/openid-configuration":
|
||||
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
|
||||
case "/random":
|
||||
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, getPublic(keySet))
|
||||
case "/private":
|
||||
writeJSON(w, defaultKeySet)
|
||||
default:
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, getPublic(defaultKeySet))
|
||||
}
|
||||
})
|
||||
|
||||
srv.Start()
|
||||
return srv
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestProvisionerInit(t *testing.T) {
|
||||
type ProvisionerValidateTest struct {
|
||||
p *Provisioner
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
||||
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &Provisioner{},
|
||||
err: errors.New("provisioner name cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &Provisioner{Name: "foo"},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-key": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &Provisioner{Name: "foo", Type: "bar"},
|
||||
err: errors.New("provisioner key cannot be empty"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &Provisioner{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for name, get := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := get(t)
|
||||
err := tc.p.Init(&globalProvisionerClaims)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,115 +1,25 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||
const DefaultProvisionersLimit = 20
|
||||
|
||||
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
||||
const DefaultProvisionersMax = 100
|
||||
|
||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||
val, ok := a.encryptedKeyIndex.Load(kid)
|
||||
key, ok := a.provisioners.LoadEncryptedKey(kid)
|
||||
if !ok {
|
||||
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
||||
http.StatusNotFound, context{}}
|
||||
}
|
||||
|
||||
key, ok := val.(string)
|
||||
if !ok {
|
||||
return "", &apiError{errors.Errorf("stored value is not a string"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetProvisioners returns a map listing each provisioner and the JWK Key Set
|
||||
// with their public keys.
|
||||
func (a *Authority) GetProvisioners(cursor string, limit int) ([]*Provisioner, string, error) {
|
||||
provisioners, nextCursor := a.sortedProvisioners.Find(cursor, limit)
|
||||
func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) {
|
||||
provisioners, nextCursor := a.provisioners.Find(cursor, limit)
|
||||
return provisioners, nextCursor, nil
|
||||
}
|
||||
|
||||
type uidProvisioner struct {
|
||||
provisioner *Provisioner
|
||||
uid string
|
||||
}
|
||||
|
||||
func newSortedProvisioners(provisioners []*Provisioner) (provisionerSlice, error) {
|
||||
if len(provisioners) > math.MaxInt32 {
|
||||
return nil, errors.New("too many provisioners")
|
||||
}
|
||||
|
||||
var slice provisionerSlice
|
||||
bi := make([]byte, 4)
|
||||
for i, p := range provisioners {
|
||||
sum, err := provisionerSum(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||
// Using big endian format to get the strings sorted:
|
||||
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||
binary.BigEndian.PutUint32(bi, uint32(i))
|
||||
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
|
||||
slice = append(slice, uidProvisioner{
|
||||
provisioner: p,
|
||||
uid: hex.EncodeToString(sum),
|
||||
})
|
||||
}
|
||||
sort.Sort(slice)
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
type provisionerSlice []uidProvisioner
|
||||
|
||||
func (p provisionerSlice) Len() int { return len(p) }
|
||||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
func (p provisionerSlice) Find(cursor string, limit int) ([]*Provisioner, string) {
|
||||
switch {
|
||||
case limit <= 0:
|
||||
limit = DefaultProvisionersLimit
|
||||
case limit > DefaultProvisionersMax:
|
||||
limit = DefaultProvisionersMax
|
||||
}
|
||||
|
||||
n := len(p)
|
||||
cursor = fmt.Sprintf("%040s", cursor)
|
||||
i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor })
|
||||
|
||||
var slice []*Provisioner
|
||||
for ; i < n && len(slice) < limit; i++ {
|
||||
slice = append(slice, p[i].provisioner)
|
||||
}
|
||||
if i < n {
|
||||
return slice, strings.TrimLeft(p[i].uid, "0")
|
||||
}
|
||||
return slice, ""
|
||||
}
|
||||
|
||||
// provisionerSum returns the SHA1 of the json representation of the
|
||||
// provisioner. From this we will create the unique and sorted id.
|
||||
func provisionerSum(p *Provisioner) ([]byte, error) {
|
||||
b, err := json.Marshal(p.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshalling provisioner")
|
||||
}
|
||||
sum := sha1.Sum(b)
|
||||
return sum[:], nil
|
||||
}
|
||||
|
|
|
@ -1,16 +1,12 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
func TestGetEncryptedKey(t *testing.T) {
|
||||
|
@ -27,7 +23,7 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return &ek{
|
||||
a: a,
|
||||
kid: c.AuthorityConfig.Provisioners[1].Key.KeyID,
|
||||
kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID,
|
||||
}
|
||||
},
|
||||
"fail-not-found": func(t *testing.T) *ek {
|
||||
|
@ -42,19 +38,6 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
http.StatusNotFound, context{}},
|
||||
}
|
||||
},
|
||||
"fail-invalid-type-found": func(t *testing.T) *ek {
|
||||
c, err := LoadConfiguration("../ca/testdata/ca.json")
|
||||
assert.FatalError(t, err)
|
||||
a, err := New(c)
|
||||
assert.FatalError(t, err)
|
||||
a.encryptedKeyIndex.Store("foo", 5)
|
||||
return &ek{
|
||||
a: a,
|
||||
kid: "foo",
|
||||
err: &apiError{errors.Errorf("stored value is not a string"),
|
||||
http.StatusInternalServerError, context{}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for name, genTestCase := range tests {
|
||||
|
@ -75,9 +58,9 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
val, ok := tc.a.provisionerIDIndex.Load("max:" + tc.kid)
|
||||
val, ok := tc.a.provisioners.Load("max:" + tc.kid)
|
||||
assert.Fatal(t, ok)
|
||||
p, ok := val.(*Provisioner)
|
||||
p, ok := val.(*provisioner.JWK)
|
||||
assert.Fatal(t, ok)
|
||||
assert.Equals(t, p.EncryptedKey, ek)
|
||||
}
|
||||
|
@ -126,102 +109,3 @@ func TestGetProvisioners(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateProvisioner(t *testing.T) *Provisioner {
|
||||
name, err := randutil.Alphanumeric(10)
|
||||
assert.FatalError(t, err)
|
||||
// Create a new JWK
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
// Encrypt JWK
|
||||
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(jwk)
|
||||
assert.FatalError(t, err)
|
||||
recipient := jose.Recipient{
|
||||
Algorithm: jose.PBES2_HS256_A128KW,
|
||||
Key: []byte("password"),
|
||||
PBES2Count: jose.PBKDF2Iterations,
|
||||
PBES2Salt: salt,
|
||||
}
|
||||
opts := new(jose.EncrypterOptions)
|
||||
opts.WithContentType(jose.ContentType("jwk+json"))
|
||||
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
|
||||
assert.FatalError(t, err)
|
||||
jwe, err := encrypter.Encrypt(b)
|
||||
assert.FatalError(t, err)
|
||||
// get public and encrypted keys
|
||||
public := jwk.Public()
|
||||
encrypted, err := jwe.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
return &Provisioner{
|
||||
Name: name,
|
||||
Type: "JWT",
|
||||
Key: &public,
|
||||
EncryptedKey: encrypted,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newSortedProvisioners(t *testing.T) {
|
||||
provisioners := make([]*Provisioner, 20)
|
||||
for i := range provisioners {
|
||||
provisioners[i] = generateProvisioner(t)
|
||||
}
|
||||
|
||||
ps, err := newSortedProvisioners(provisioners)
|
||||
assert.FatalError(t, err)
|
||||
prev := ""
|
||||
for i, p := range ps {
|
||||
if p.uid < prev {
|
||||
t.Errorf("%s should be less that %s", p.uid, prev)
|
||||
}
|
||||
if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
|
||||
t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
|
||||
}
|
||||
prev = p.uid
|
||||
}
|
||||
}
|
||||
|
||||
func Test_provisionerSlice_Find(t *testing.T) {
|
||||
trim := func(s string) string {
|
||||
return strings.TrimLeft(s, "0")
|
||||
}
|
||||
provisioners := make([]*Provisioner, 20)
|
||||
for i := range provisioners {
|
||||
provisioners[i] = generateProvisioner(t)
|
||||
}
|
||||
ps, err := newSortedProvisioners(provisioners)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cursor string
|
||||
limit int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
p provisionerSlice
|
||||
args args
|
||||
want []*Provisioner
|
||||
want1 string
|
||||
}{
|
||||
{"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
|
||||
{"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
|
||||
{"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
|
||||
{"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
|
||||
{"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
|
||||
{"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
|
||||
{"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
|
||||
{"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package authority
|
|||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
|
@ -11,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
|
@ -22,48 +22,7 @@ func (a *Authority) GetTLSOptions() *tlsutil.TLSOptions {
|
|||
return a.config.TLS
|
||||
}
|
||||
|
||||
// SignOptions contains the options that can be passed to the Authority.Sign
|
||||
// method.
|
||||
type SignOptions struct {
|
||||
NotAfter time.Time `json:"notAfter"`
|
||||
NotBefore time.Time `json:"notBefore"`
|
||||
}
|
||||
|
||||
var (
|
||||
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
||||
oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35}
|
||||
)
|
||||
|
||||
type stepProvisionerASN1 struct {
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
}
|
||||
|
||||
const provisionerTypeJWK = 1
|
||||
|
||||
func withProvisionerOID(name, kid string) x509util.WithOption {
|
||||
return func(p x509util.Profile) error {
|
||||
crt := p.Subject()
|
||||
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||
Type: provisionerTypeJWK,
|
||||
Name: []byte(name),
|
||||
CredentialID: []byte(kid),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{
|
||||
Id: stepOIDProvisioner,
|
||||
Critical: false,
|
||||
Value: b,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
var oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35}
|
||||
|
||||
func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
||||
return func(p x509util.Profile) error {
|
||||
|
@ -96,28 +55,22 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
|
|||
}
|
||||
|
||||
// Sign creates a signed certificate from a certificate signing request.
|
||||
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) {
|
||||
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) {
|
||||
var (
|
||||
errContext = context{"csr": csr, "signOptions": signOpts}
|
||||
claims = []certClaim{}
|
||||
mods = []x509util.WithOption{}
|
||||
errContext = context{"csr": csr, "signOptions": signOpts}
|
||||
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
|
||||
certValidators = []provisioner.CertificateValidator{}
|
||||
)
|
||||
for _, op := range extraOpts {
|
||||
switch k := op.(type) {
|
||||
case certClaim:
|
||||
claims = append(claims, k)
|
||||
case x509util.WithOption:
|
||||
mods = append(mods, k)
|
||||
case *Provisioner:
|
||||
m, c, err := k.getTLSApps(signOpts)
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{err, http.StatusInternalServerError, errContext}
|
||||
case provisioner.CertificateValidator:
|
||||
certValidators = append(certValidators, k)
|
||||
case provisioner.CertificateRequestValidator:
|
||||
if err := k.Valid(csr); err != nil {
|
||||
return nil, nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
mods = append(mods, m...)
|
||||
mods = append(mods, []x509util.WithOption{
|
||||
withDefaultASN1DN(a.config.AuthorityConfig.Template),
|
||||
}...)
|
||||
claims = append(claims, c...)
|
||||
case provisioner.ProfileModifier:
|
||||
mods = append(mods, k.Option(signOpts))
|
||||
default:
|
||||
return nil, nil, &apiError{errors.Errorf("sign: invalid extra option type %T", k),
|
||||
http.StatusInternalServerError, errContext}
|
||||
|
@ -137,10 +90,6 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, ext
|
|||
return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusInternalServerError, errContext}
|
||||
}
|
||||
|
||||
if err := validateClaims(leaf.Subject(), claims); err != nil {
|
||||
return nil, nil, &apiError{errors.Wrapf(err, "sign"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
|
||||
crtBytes, err := leaf.CreateCertificate()
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{errors.Wrap(err, "sign: error creating new leaf certificate"),
|
||||
|
@ -153,6 +102,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts SignOptions, ext
|
|||
http.StatusInternalServerError, errContext}
|
||||
}
|
||||
|
||||
// FIXME: This should be before creating the certificate.
|
||||
for _, v := range certValidators {
|
||||
if err := v.Valid(serverCert); err != nil {
|
||||
return nil, nil, &apiError{errors.Wrap(err, "sign"), http.StatusUnauthorized, errContext}
|
||||
}
|
||||
}
|
||||
|
||||
caCert, err := x509.ParseCertificate(issIdentity.Crt.Raw)
|
||||
if err != nil {
|
||||
return nil, nil, &apiError{errors.Wrap(err, "sign: error parsing intermediate certificate"),
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
@ -15,12 +14,49 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
"github.com/smallstep/cli/jose"
|
||||
stepx509 "github.com/smallstep/cli/pkg/x509"
|
||||
)
|
||||
|
||||
var (
|
||||
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
||||
)
|
||||
|
||||
const provisionerTypeJWK = 1
|
||||
|
||||
type stepProvisionerASN1 struct {
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
}
|
||||
|
||||
func withProvisionerOID(name, kid string) x509util.WithOption {
|
||||
return func(p x509util.Profile) error {
|
||||
crt := p.Subject()
|
||||
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||
Type: provisionerTypeJWK,
|
||||
Name: []byte(name),
|
||||
CredentialID: []byte(kid),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{
|
||||
Id: stepOIDProvisioner,
|
||||
Critical: false,
|
||||
Value: b,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateRequest)) *x509.CertificateRequest {
|
||||
_csr := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{CommonName: "smallstep test"},
|
||||
|
@ -52,24 +88,25 @@ func TestSign(t *testing.T) {
|
|||
}
|
||||
|
||||
nb := time.Now()
|
||||
signOpts := SignOptions{
|
||||
signOpts := provisioner.Options{
|
||||
NotBefore: nb,
|
||||
NotAfter: nb.Add(time.Minute * 5),
|
||||
}
|
||||
|
||||
p := a.config.AuthorityConfig.Provisioners[1]
|
||||
extraOpts := []interface{}{
|
||||
&commonNameClaim{"smallstep test"},
|
||||
&dnsNamesClaim{[]string{"test.smallstep.com"}},
|
||||
&ipAddressesClaim{[]net.IP{}},
|
||||
p,
|
||||
}
|
||||
// Create a token to get test extra opts.
|
||||
p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK)
|
||||
key, err := jose.ParseKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||
assert.FatalError(t, err)
|
||||
token, err := generateToken("smallstep test", "step-cli", "https://test.ca.smallstep.com/sign", []string{"test.smallstep.com"}, time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
extraOpts, err := a.Authorize(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type signTest struct {
|
||||
auth *Authority
|
||||
csr *x509.CertificateRequest
|
||||
signOpts SignOptions
|
||||
extraOpts []interface{}
|
||||
signOpts provisioner.Options
|
||||
extraOpts []provisioner.SignOption
|
||||
err *apiError
|
||||
}
|
||||
tests := map[string]func(*testing.T) *signTest{
|
||||
|
@ -123,7 +160,7 @@ func TestSign(t *testing.T) {
|
|||
return &signTest{
|
||||
auth: _a,
|
||||
csr: csr,
|
||||
extraOpts: []interface{}{p},
|
||||
extraOpts: extraOpts,
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: error creating new leaf certificate"),
|
||||
http.StatusInternalServerError,
|
||||
|
@ -133,7 +170,7 @@ func TestSign(t *testing.T) {
|
|||
},
|
||||
"fail provisioner duration claim": func(t *testing.T) *signTest {
|
||||
csr := getCSR(t, priv)
|
||||
_signOpts := SignOptions{
|
||||
_signOpts := provisioner.Options{
|
||||
NotBefore: nb,
|
||||
NotAfter: nb.Add(time.Hour * 25),
|
||||
}
|
||||
|
@ -157,7 +194,7 @@ func TestSign(t *testing.T) {
|
|||
csr: csr,
|
||||
extraOpts: extraOpts,
|
||||
signOpts: signOpts,
|
||||
err: &apiError{errors.New("sign: DNS names claim failed - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
||||
err: &apiError{errors.New("sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
|
||||
http.StatusUnauthorized,
|
||||
context{"csr": csr, "signOptions": signOpts},
|
||||
},
|
||||
|
@ -262,7 +299,7 @@ func TestRenew(t *testing.T) {
|
|||
now := time.Now().UTC()
|
||||
nb1 := now.Add(-time.Minute * 7)
|
||||
na1 := now
|
||||
so := &SignOptions{
|
||||
so := &provisioner.Options{
|
||||
NotBefore: nb1,
|
||||
NotAfter: na1,
|
||||
}
|
||||
|
@ -272,7 +309,7 @@ func TestRenew(t *testing.T) {
|
|||
x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0),
|
||||
withDefaultASN1DN(a.config.AuthorityConfig.Template),
|
||||
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
|
||||
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].Key.KeyID))
|
||||
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID))
|
||||
assert.FatalError(t, err)
|
||||
crtBytes, err := leaf.CreateCertificate()
|
||||
assert.FatalError(t, err)
|
||||
|
@ -284,7 +321,7 @@ func TestRenew(t *testing.T) {
|
|||
x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0),
|
||||
withDefaultASN1DN(a.config.AuthorityConfig.Template),
|
||||
x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"),
|
||||
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID),
|
||||
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID),
|
||||
)
|
||||
assert.FatalError(t, err)
|
||||
crtBytesNoRenew, err := leafNoRenew.CreateCertificate()
|
||||
|
@ -321,7 +358,7 @@ func TestRenew(t *testing.T) {
|
|||
}
|
||||
return &renewTest{
|
||||
crt: crtNoRenew,
|
||||
err: &apiError{errors.New("renew disabled"),
|
||||
err: &apiError{errors.New("renew is disabled for provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
|
||||
http.StatusUnauthorized, ctx},
|
||||
}, nil
|
||||
},
|
||||
|
|
|
@ -2,48 +2,10 @@ package authority
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.Duration.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *Duration) UnmarshalJSON(data []byte) (err error) {
|
||||
var (
|
||||
s string
|
||||
_d time.Duration
|
||||
)
|
||||
if d == nil {
|
||||
return errors.New("duration cannot be nil")
|
||||
}
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
if _d, err = time.ParseDuration(s); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||
}
|
||||
d.Duration = _d
|
||||
return
|
||||
}
|
||||
|
||||
// multiString represents a type that can be encoded/decoded in JSON as a single
|
||||
// string or an array of strings.
|
||||
type multiString []string
|
||||
|
|
|
@ -3,7 +3,6 @@ package authority
|
|||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_multiString_First(t *testing.T) {
|
||||
|
@ -101,57 +100,3 @@ func Test_multiString_UnmarshalJSON(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_UnmarshalJSON(t *testing.T) {
|
||||
type args struct {
|
||||
data []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
d *Duration
|
||||
args args
|
||||
want *Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty", new(Duration), args{[]byte{}}, new(Duration), true},
|
||||
{"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true},
|
||||
{"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true},
|
||||
{"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true},
|
||||
{"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false},
|
||||
{"nil", nil, args{nil}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(tt.d, tt.want) {
|
||||
t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_duration_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
d *Duration
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.d.MarshalJSON()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
provisioners "github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/ca"
|
||||
"github.com/smallstep/cli/config"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
|
@ -111,10 +111,12 @@ func loadProvisionerJWKByName(name, caURL, caRoot, passFile string) (key *jose.J
|
|||
}
|
||||
|
||||
for _, provisioner := range provisioners {
|
||||
if provisioner.Name == name {
|
||||
key, err = decryptProvisionerJWK(provisioner.EncryptedKey, passFile)
|
||||
if err == nil {
|
||||
return
|
||||
if provisioner.GetName() == name {
|
||||
if _, encryptedKey, ok := provisioner.GetEncryptedKey(); ok {
|
||||
key, err = decryptProvisionerJWK(encryptedKey, passFile)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -154,7 +156,7 @@ func getRootCAPath() string {
|
|||
}
|
||||
|
||||
// getProvisioners returns the map of provisioners on the given CA.
|
||||
func getProvisioners(caURL, rootFile string) ([]*authority.Provisioner, error) {
|
||||
func getProvisioners(caURL, rootFile string) (provisioners.List, error) {
|
||||
if len(rootFile) == 0 {
|
||||
rootFile = getRootCAPath()
|
||||
}
|
||||
|
@ -163,7 +165,7 @@ func getProvisioners(caURL, rootFile string) ([]*authority.Provisioner, error) {
|
|||
return nil, err
|
||||
}
|
||||
cursor := ""
|
||||
provisioners := []*authority.Provisioner{}
|
||||
var provisioners provisioners.List
|
||||
for {
|
||||
resp, err := client.Provisioners(ca.WithProvisionerCursor(cursor), ca.WithProvisionerLimit(100))
|
||||
if err != nil {
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
|
@ -389,7 +390,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *ekt {
|
||||
p := config.AuthorityConfig.Provisioners[2]
|
||||
p := config.AuthorityConfig.Provisioners[2].(*provisioner.JWK)
|
||||
return &ekt{
|
||||
ca: ca,
|
||||
kid: p.Key.KeyID,
|
||||
|
|
|
@ -446,7 +446,11 @@ func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error)
|
|||
return nil, nil, errors.Wrap(err, "error generating key")
|
||||
}
|
||||
|
||||
var emails []string
|
||||
dnsNames, ips := x509util.SplitSANs(claims.SANs)
|
||||
if claims.Email != "" {
|
||||
emails = append(emails, claims.Email)
|
||||
}
|
||||
|
||||
template := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
|
@ -455,6 +459,7 @@ func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error)
|
|||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
DNSNames: dnsNames,
|
||||
IPAddresses: ips,
|
||||
EmailAddresses: emails,
|
||||
}
|
||||
|
||||
csr, err := x509.CreateCertificateRequest(rand.Reader, template, pk)
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -391,7 +391,7 @@ func TestClient_Renew(t *testing.T) {
|
|||
|
||||
func TestClient_Provisioners(t *testing.T) {
|
||||
ok := &api.ProvisionersResponse{
|
||||
Provisioners: []*authority.Provisioner{},
|
||||
Provisioners: provisioner.List{},
|
||||
}
|
||||
internalServerError := api.InternalServerError(fmt.Errorf("Internal Server Error"))
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
@ -143,40 +144,12 @@ intermediate private key.`,
|
|||
}
|
||||
|
||||
app.Action = func(ctx *cli.Context) error {
|
||||
passFile := ctx.String("password-file")
|
||||
|
||||
// If zero cmd line args show help, if >1 cmd line args show error.
|
||||
if ctx.NArg() == 0 {
|
||||
return cli.ShowAppHelp(ctx)
|
||||
}
|
||||
if err := errs.NumberOfArguments(ctx, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configFile := ctx.Args().Get(0)
|
||||
config, err := authority.LoadConfiguration(configFile)
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
|
||||
var password []byte
|
||||
if passFile != "" {
|
||||
if password, err = ioutil.ReadFile(passFile); err != nil {
|
||||
fatal(errors.Wrapf(err, "error reading %s", passFile))
|
||||
}
|
||||
password = bytes.TrimRightFunc(password, unicode.IsSpace)
|
||||
}
|
||||
|
||||
srv, err := ca.New(config, ca.WithConfigFile(configFile), ca.WithPassword(password))
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
|
||||
go ca.StopReloaderHandler(srv)
|
||||
if err = srv.Run(); err != nil && err != http.ErrServerClosed {
|
||||
fatal(err)
|
||||
}
|
||||
return nil
|
||||
// Hack to be able to run a the top action as a subcommand
|
||||
cmd := cli.Command{Name: "start", Action: startAction, Flags: app.Flags}
|
||||
set := flag.NewFlagSet(app.Name, flag.ContinueOnError)
|
||||
set.Parse(os.Args)
|
||||
ctx = cli.NewContext(app, set, nil)
|
||||
return cmd.Run(ctx)
|
||||
}
|
||||
|
||||
if err := app.Run(os.Args); err != nil {
|
||||
|
@ -189,6 +162,43 @@ intermediate private key.`,
|
|||
}
|
||||
}
|
||||
|
||||
func startAction(ctx *cli.Context) error {
|
||||
passFile := ctx.String("password-file")
|
||||
|
||||
// If zero cmd line args show help, if >1 cmd line args show error.
|
||||
if ctx.NArg() == 0 {
|
||||
return cli.ShowAppHelp(ctx)
|
||||
}
|
||||
if err := errs.NumberOfArguments(ctx, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configFile := ctx.Args().Get(0)
|
||||
config, err := authority.LoadConfiguration(configFile)
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
|
||||
var password []byte
|
||||
if passFile != "" {
|
||||
if password, err = ioutil.ReadFile(passFile); err != nil {
|
||||
fatal(errors.Wrapf(err, "error reading %s", passFile))
|
||||
}
|
||||
password = bytes.TrimRightFunc(password, unicode.IsSpace)
|
||||
}
|
||||
|
||||
srv, err := ca.New(config, ca.WithConfigFile(configFile), ca.WithPassword(password))
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
|
||||
go ca.StopReloaderHandler(srv)
|
||||
if err = srv.Run(); err != nil && err != http.ErrServerClosed {
|
||||
fatal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fatal writes the passed error on the standard error and exits with the exit
|
||||
// code 1. If the environment variable STEPDEBUG is set to 1 it shows the
|
||||
// stack trace of the error.
|
||||
|
|
Loading…
Reference in a new issue