Add initial implementation of an OIDC provisioner.

This commit is contained in:
Mariano Cano 2019-03-04 17:58:20 -08:00
parent 98b3d971f6
commit a2a45f635b
4 changed files with 341 additions and 0 deletions

View file

@ -0,0 +1,17 @@
package provisioner
import "time"
// Claims so that individual provisioners can override global claims.
type Claims struct {
globalClaims *Claims
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"`
}
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
type Duration struct {
time.Duration
}

View file

@ -0,0 +1,122 @@
package provisioner
import (
"encoding/json"
"math/rand"
"net/http"
"regexp"
"strconv"
"sync"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
const (
defaultCacheAge = 12 * time.Hour
defaultCacheJitter = 1 * time.Hour
)
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
type keyStore struct {
sync.RWMutex
uri string
keys jose.JSONWebKeySet
timer *time.Timer
expiry time.Time
}
func newKeyStore(uri string) (*keyStore, error) {
keys, age, err := getKeysFromJWKsURI(uri)
if err != nil {
return nil, err
}
ks := &keyStore{
uri: uri,
keys: keys,
expiry: getExpirationTime(age),
}
ks.timer = time.AfterFunc(age, ks.reload)
return ks, nil
}
func (ks *keyStore) Close() {
ks.timer.Stop()
}
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
ks.RLock()
// Force reload if expiration has passed
if time.Now().After(ks.expiry) {
ks.RUnlock()
ks.reload()
ks.RLock()
}
keys = ks.keys.Key(kid)
ks.RUnlock()
return
}
func (ks *keyStore) reload() {
var next time.Duration
keys, age, err := getKeysFromJWKsURI(ks.uri)
if err != nil {
next = ks.nextReloadDuration(defaultCacheJitter / 2)
} else {
ks.Lock()
ks.keys = keys
ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
ks.Unlock()
next = ks.nextReloadDuration(age)
}
ks.Lock()
ks.timer.Reset(next)
ks.Unlock()
}
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
n := rand.Int63n(int64(defaultCacheJitter))
age -= time.Duration(n)
if age < 0 {
age = 0
}
return age
}
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
var keys jose.JSONWebKeySet
resp, err := http.Get(uri)
if err != nil {
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
}
defer resp.Body.Close()
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
return keys, 0, errors.Wrapf(err, "error reading %s", uri)
}
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
}
func getCacheAge(cacheControl string) time.Duration {
age := defaultCacheAge
if len(cacheControl) > 0 {
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
if len(match) > 0 {
if len(match[0]) == 2 {
maxAge := match[0][1]
maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
if err != nil {
return defaultCacheAge
}
age = time.Duration(maxAgeInt) * time.Second
}
}
}
return age
}
func getExpirationTime(age time.Duration) time.Time {
return time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
}

View file

@ -0,0 +1,147 @@
package provisioner
import (
"encoding/json"
"net/http"
"time"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
)
type openIDConfiguration struct {
Issuer string `json:"issuer"`
JWKSetURI string `json:"jwks_uri"`
}
// openIDPayload represents the fields on the id_token JWT payload.
type openIDPayload struct {
jose.Claims
AtHash string `json:"at_hash"`
AuthorizedParty string `json:"azp"`
Email string `json:"email"`
EmailVerified string `json:"email_verified"`
Hd string `json:"hd"`
Nonce string `json:"nonce"`
}
// OIDC represents an OAuth 2.0 OpenID Connect provider.
type OIDC struct {
Type string `json:"type"`
Name string `json:"name"`
ClientID string `json:"clientID"`
ConfigurationEndpoint string `json:"configurationEndpoint"`
Claims *Claims `json:"claims,omitempty"`
Admins []string `json:"admins"`
configuration openIDConfiguration
keyStore *keyStore
}
// IsAdmin returns true if the given email is in the Admins whitelist, false
// otherwise.
func (o *OIDC) IsAdmin(email string) bool {
for _, e := range o.Admins {
if e == email {
return true
}
}
return false
}
// Validate validates and initializes the OIDC provider.
func (o *OIDC) Validate() error {
switch {
case o.Name == "":
return errors.New("name cannot be empty")
case o.ClientID == "":
return errors.New("clientID cannot be empty")
case o.ConfigurationEndpoint == "":
return errors.New("configurationEndpoint cannot be empty")
}
// Decode openid-configuration endpoint
var conf openIDConfiguration
if err := getAndDecode(o.ConfigurationEndpoint, &conf); err != nil {
return err
}
if conf.JWKSetURI == "" {
return errors.Errorf("error parsing %s: jwks_uri cannot be empty", o.ConfigurationEndpoint)
}
// Get JWK key set
keyStore, err := newKeyStore(conf.JWKSetURI)
if err != nil {
return err
}
o.configuration = conf
o.keyStore = keyStore
return nil
}
// ValidatePayload validates the given token payload.
//
// TODO(mariano): avoid reply attacks validating nonce.
func (o *OIDC) ValidatePayload(p openIDPayload) error {
// According to "rfc7519 JSON Web Token" acceptable skew should be no more
// than a few minutes.
if err := p.ValidateWithLeeway(jose.Expected{
Issuer: o.configuration.Issuer,
Audience: jose.Audience{o.ClientID},
}, time.Minute); err != nil {
return errors.Wrap(err, "failed to validate payload")
}
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
return errors.New("failed to validate payload: invalid azp")
}
return nil
}
// Authorize validates the given token.
func (o *OIDC) Authorize(token string) ([]SignOption, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrapf(err, "error parsing token")
}
var claims openIDPayload
// Parse claims to get the kid
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims")
}
found := false
kid := jwt.Headers[0].KeyID
keys := o.keyStore.Get(kid)
for _, key := range keys {
if err := jwt.Claims(key, &claims); err == nil {
found = true
break
}
}
if !found {
return nil, errors.New("cannot validate token")
}
if err := o.ValidatePayload(claims); err != nil {
return nil, err
}
if o.IsAdmin(claims.Email) {
return []SignOption{}, nil
}
return []SignOption{
emailOnlyIdentity(claims.Email),
}, nil
}
func getAndDecode(uri string, v interface{}) error {
resp, err := http.Get(uri)
if err != nil {
return errors.Wrapf(err, "failed to connect to %s", uri)
}
defer resp.Body.Close()
if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
return errors.Wrapf(err, "error reading %s", uri)
}
return nil
}

View file

@ -0,0 +1,55 @@
package provisioner
import (
"crypto/x509"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/x509util"
)
// SignOption is the interface used to collect all extra options used in the
// Sign method.
type SignOption interface{}
// CertificateValidator is the interface used to validate a X.509 certificate.
type CertificateValidator interface {
SignOption
Valid(crt *x509.Certificate) error
}
// CertificateRequestValidator is the interface used to validate a X.509
// certificate request.
type CertificateRequestValidator interface {
SignOption
Valid(req *x509.CertificateRequest)
}
// ProfileWithOption is the interface used to add custom options to the profile
// constructor. The options are used to modify the final certificate.
type ProfileWithOption interface {
SignOption
Option() x509util.WithOption
}
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
// SAN provided is the given email address.
type emailOnlyIdentity string
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
switch {
case len(req.DNSNames) > 0:
return errors.New("certificate request cannot contain DNS names")
case len(req.IPAddresses) > 0:
return errors.New("certificate request cannot contain IP addresses")
case len(req.URIs) > 0:
return errors.New("certificate request cannot contain URIs")
case len(req.EmailAddresses) == 0:
return errors.New("certificate request does not contain any email address")
case len(req.EmailAddresses) > 1:
return errors.New("certificate request does not contain too many email addresses")
case req.EmailAddresses[0] != string(e):
return errors.New("certificate request does not contain the valid email address")
default:
return nil
}
}