Add initial implementation of an OIDC provisioner.
This commit is contained in:
parent
98b3d971f6
commit
a2a45f635b
4 changed files with 341 additions and 0 deletions
17
authority/provisioner/claims.go
Normal file
17
authority/provisioner/claims.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package provisioner
|
||||
|
||||
import "time"
|
||||
|
||||
// Claims so that individual provisioners can override global claims.
|
||||
type Claims struct {
|
||||
globalClaims *Claims
|
||||
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||
}
|
||||
|
||||
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
122
authority/provisioner/keystore.go
Normal file
122
authority/provisioner/keystore.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCacheAge = 12 * time.Hour
|
||||
defaultCacheJitter = 1 * time.Hour
|
||||
)
|
||||
|
||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
||||
|
||||
type keyStore struct {
|
||||
sync.RWMutex
|
||||
uri string
|
||||
keys jose.JSONWebKeySet
|
||||
timer *time.Timer
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
func newKeyStore(uri string) (*keyStore, error) {
|
||||
keys, age, err := getKeysFromJWKsURI(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ks := &keyStore{
|
||||
uri: uri,
|
||||
keys: keys,
|
||||
expiry: getExpirationTime(age),
|
||||
}
|
||||
ks.timer = time.AfterFunc(age, ks.reload)
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
func (ks *keyStore) Close() {
|
||||
ks.timer.Stop()
|
||||
}
|
||||
|
||||
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
||||
ks.RLock()
|
||||
// Force reload if expiration has passed
|
||||
if time.Now().After(ks.expiry) {
|
||||
ks.RUnlock()
|
||||
ks.reload()
|
||||
ks.RLock()
|
||||
}
|
||||
keys = ks.keys.Key(kid)
|
||||
ks.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (ks *keyStore) reload() {
|
||||
var next time.Duration
|
||||
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
||||
if err != nil {
|
||||
next = ks.nextReloadDuration(defaultCacheJitter / 2)
|
||||
} else {
|
||||
ks.Lock()
|
||||
ks.keys = keys
|
||||
ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
||||
ks.Unlock()
|
||||
next = ks.nextReloadDuration(age)
|
||||
}
|
||||
|
||||
ks.Lock()
|
||||
ks.timer.Reset(next)
|
||||
ks.Unlock()
|
||||
}
|
||||
|
||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||
n := rand.Int63n(int64(defaultCacheJitter))
|
||||
age -= time.Duration(n)
|
||||
if age < 0 {
|
||||
age = 0
|
||||
}
|
||||
return age
|
||||
}
|
||||
|
||||
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||
var keys jose.JSONWebKeySet
|
||||
resp, err := http.Get(uri)
|
||||
if err != nil {
|
||||
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
|
||||
return keys, 0, errors.Wrapf(err, "error reading %s", uri)
|
||||
}
|
||||
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
|
||||
}
|
||||
|
||||
func getCacheAge(cacheControl string) time.Duration {
|
||||
age := defaultCacheAge
|
||||
if len(cacheControl) > 0 {
|
||||
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
|
||||
if len(match) > 0 {
|
||||
if len(match[0]) == 2 {
|
||||
maxAge := match[0][1]
|
||||
maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
|
||||
if err != nil {
|
||||
return defaultCacheAge
|
||||
}
|
||||
age = time.Duration(maxAgeInt) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
return age
|
||||
}
|
||||
|
||||
func getExpirationTime(age time.Duration) time.Time {
|
||||
return time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
||||
}
|
147
authority/provisioner/oidc.go
Normal file
147
authority/provisioner/oidc.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
type openIDConfiguration struct {
|
||||
Issuer string `json:"issuer"`
|
||||
JWKSetURI string `json:"jwks_uri"`
|
||||
}
|
||||
|
||||
// openIDPayload represents the fields on the id_token JWT payload.
|
||||
type openIDPayload struct {
|
||||
jose.Claims
|
||||
AtHash string `json:"at_hash"`
|
||||
AuthorizedParty string `json:"azp"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified string `json:"email_verified"`
|
||||
Hd string `json:"hd"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// OIDC represents an OAuth 2.0 OpenID Connect provider.
|
||||
type OIDC struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
ClientID string `json:"clientID"`
|
||||
ConfigurationEndpoint string `json:"configurationEndpoint"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Admins []string `json:"admins"`
|
||||
configuration openIDConfiguration
|
||||
keyStore *keyStore
|
||||
}
|
||||
|
||||
// IsAdmin returns true if the given email is in the Admins whitelist, false
|
||||
// otherwise.
|
||||
func (o *OIDC) IsAdmin(email string) bool {
|
||||
for _, e := range o.Admins {
|
||||
if e == email {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate validates and initializes the OIDC provider.
|
||||
func (o *OIDC) Validate() error {
|
||||
switch {
|
||||
case o.Name == "":
|
||||
return errors.New("name cannot be empty")
|
||||
case o.ClientID == "":
|
||||
return errors.New("clientID cannot be empty")
|
||||
case o.ConfigurationEndpoint == "":
|
||||
return errors.New("configurationEndpoint cannot be empty")
|
||||
}
|
||||
|
||||
// Decode openid-configuration endpoint
|
||||
var conf openIDConfiguration
|
||||
if err := getAndDecode(o.ConfigurationEndpoint, &conf); err != nil {
|
||||
return err
|
||||
}
|
||||
if conf.JWKSetURI == "" {
|
||||
return errors.Errorf("error parsing %s: jwks_uri cannot be empty", o.ConfigurationEndpoint)
|
||||
}
|
||||
// Get JWK key set
|
||||
keyStore, err := newKeyStore(conf.JWKSetURI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.configuration = conf
|
||||
o.keyStore = keyStore
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePayload validates the given token payload.
|
||||
//
|
||||
// TODO(mariano): avoid reply attacks validating nonce.
|
||||
func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no more
|
||||
// than a few minutes.
|
||||
if err := p.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: o.configuration.Issuer,
|
||||
Audience: jose.Audience{o.ClientID},
|
||||
}, time.Minute); err != nil {
|
||||
return errors.Wrap(err, "failed to validate payload")
|
||||
}
|
||||
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
|
||||
return errors.New("failed to validate payload: invalid azp")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authorize validates the given token.
|
||||
func (o *OIDC) Authorize(token string) ([]SignOption, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing token")
|
||||
}
|
||||
|
||||
var claims openIDPayload
|
||||
// Parse claims to get the kid
|
||||
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing claims")
|
||||
}
|
||||
|
||||
found := false
|
||||
kid := jwt.Headers[0].KeyID
|
||||
keys := o.keyStore.Get(kid)
|
||||
for _, key := range keys {
|
||||
if err := jwt.Claims(key, &claims); err == nil {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, errors.New("cannot validate token")
|
||||
}
|
||||
|
||||
if err := o.ValidatePayload(claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if o.IsAdmin(claims.Email) {
|
||||
return []SignOption{}, nil
|
||||
}
|
||||
|
||||
return []SignOption{
|
||||
emailOnlyIdentity(claims.Email),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getAndDecode(uri string, v interface{}) error {
|
||||
resp, err := http.Get(uri)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to connect to %s", uri)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
|
||||
return errors.Wrapf(err, "error reading %s", uri)
|
||||
}
|
||||
return nil
|
||||
}
|
55
authority/provisioner/sign_options.go
Normal file
55
authority/provisioner/sign_options.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
)
|
||||
|
||||
// SignOption is the interface used to collect all extra options used in the
|
||||
// Sign method.
|
||||
type SignOption interface{}
|
||||
|
||||
// CertificateValidator is the interface used to validate a X.509 certificate.
|
||||
type CertificateValidator interface {
|
||||
SignOption
|
||||
Valid(crt *x509.Certificate) error
|
||||
}
|
||||
|
||||
// CertificateRequestValidator is the interface used to validate a X.509
|
||||
// certificate request.
|
||||
type CertificateRequestValidator interface {
|
||||
SignOption
|
||||
Valid(req *x509.CertificateRequest)
|
||||
}
|
||||
|
||||
// ProfileWithOption is the interface used to add custom options to the profile
|
||||
// constructor. The options are used to modify the final certificate.
|
||||
type ProfileWithOption interface {
|
||||
SignOption
|
||||
Option() x509util.WithOption
|
||||
}
|
||||
|
||||
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
||||
// SAN provided is the given email address.
|
||||
type emailOnlyIdentity string
|
||||
|
||||
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
||||
switch {
|
||||
case len(req.DNSNames) > 0:
|
||||
return errors.New("certificate request cannot contain DNS names")
|
||||
case len(req.IPAddresses) > 0:
|
||||
return errors.New("certificate request cannot contain IP addresses")
|
||||
case len(req.URIs) > 0:
|
||||
return errors.New("certificate request cannot contain URIs")
|
||||
case len(req.EmailAddresses) == 0:
|
||||
return errors.New("certificate request does not contain any email address")
|
||||
case len(req.EmailAddresses) > 1:
|
||||
return errors.New("certificate request does not contain too many email addresses")
|
||||
case req.EmailAddresses[0] != string(e):
|
||||
return errors.New("certificate request does not contain the valid email address")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue