forked from TrueCloudLab/certificates
Add initial implementation of an OIDC provisioner.
This commit is contained in:
parent
98b3d971f6
commit
a2a45f635b
4 changed files with 341 additions and 0 deletions
17
authority/provisioner/claims.go
Normal file
17
authority/provisioner/claims.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// Claims so that individual provisioners can override global claims.
|
||||||
|
type Claims struct {
|
||||||
|
globalClaims *Claims
|
||||||
|
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||||
|
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||||
|
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||||
|
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
|
||||||
|
type Duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
122
authority/provisioner/keystore.go
Normal file
122
authority/provisioner/keystore.go
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultCacheAge = 12 * time.Hour
|
||||||
|
defaultCacheJitter = 1 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
||||||
|
|
||||||
|
type keyStore struct {
|
||||||
|
sync.RWMutex
|
||||||
|
uri string
|
||||||
|
keys jose.JSONWebKeySet
|
||||||
|
timer *time.Timer
|
||||||
|
expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newKeyStore(uri string) (*keyStore, error) {
|
||||||
|
keys, age, err := getKeysFromJWKsURI(uri)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ks := &keyStore{
|
||||||
|
uri: uri,
|
||||||
|
keys: keys,
|
||||||
|
expiry: getExpirationTime(age),
|
||||||
|
}
|
||||||
|
ks.timer = time.AfterFunc(age, ks.reload)
|
||||||
|
return ks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *keyStore) Close() {
|
||||||
|
ks.timer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
||||||
|
ks.RLock()
|
||||||
|
// Force reload if expiration has passed
|
||||||
|
if time.Now().After(ks.expiry) {
|
||||||
|
ks.RUnlock()
|
||||||
|
ks.reload()
|
||||||
|
ks.RLock()
|
||||||
|
}
|
||||||
|
keys = ks.keys.Key(kid)
|
||||||
|
ks.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *keyStore) reload() {
|
||||||
|
var next time.Duration
|
||||||
|
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
||||||
|
if err != nil {
|
||||||
|
next = ks.nextReloadDuration(defaultCacheJitter / 2)
|
||||||
|
} else {
|
||||||
|
ks.Lock()
|
||||||
|
ks.keys = keys
|
||||||
|
ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
||||||
|
ks.Unlock()
|
||||||
|
next = ks.nextReloadDuration(age)
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.Lock()
|
||||||
|
ks.timer.Reset(next)
|
||||||
|
ks.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||||
|
n := rand.Int63n(int64(defaultCacheJitter))
|
||||||
|
age -= time.Duration(n)
|
||||||
|
if age < 0 {
|
||||||
|
age = 0
|
||||||
|
}
|
||||||
|
return age
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||||
|
var keys jose.JSONWebKeySet
|
||||||
|
resp, err := http.Get(uri)
|
||||||
|
if err != nil {
|
||||||
|
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
|
||||||
|
return keys, 0, errors.Wrapf(err, "error reading %s", uri)
|
||||||
|
}
|
||||||
|
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCacheAge(cacheControl string) time.Duration {
|
||||||
|
age := defaultCacheAge
|
||||||
|
if len(cacheControl) > 0 {
|
||||||
|
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
|
||||||
|
if len(match) > 0 {
|
||||||
|
if len(match[0]) == 2 {
|
||||||
|
maxAge := match[0][1]
|
||||||
|
maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return defaultCacheAge
|
||||||
|
}
|
||||||
|
age = time.Duration(maxAgeInt) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return age
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExpirationTime(age time.Duration) time.Time {
|
||||||
|
return time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
||||||
|
}
|
147
authority/provisioner/oidc.go
Normal file
147
authority/provisioner/oidc.go
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openIDConfiguration struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
JWKSetURI string `json:"jwks_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// openIDPayload represents the fields on the id_token JWT payload.
|
||||||
|
type openIDPayload struct {
|
||||||
|
jose.Claims
|
||||||
|
AtHash string `json:"at_hash"`
|
||||||
|
AuthorizedParty string `json:"azp"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
EmailVerified string `json:"email_verified"`
|
||||||
|
Hd string `json:"hd"`
|
||||||
|
Nonce string `json:"nonce"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDC represents an OAuth 2.0 OpenID Connect provider.
|
||||||
|
type OIDC struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
ClientID string `json:"clientID"`
|
||||||
|
ConfigurationEndpoint string `json:"configurationEndpoint"`
|
||||||
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
|
Admins []string `json:"admins"`
|
||||||
|
configuration openIDConfiguration
|
||||||
|
keyStore *keyStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAdmin returns true if the given email is in the Admins whitelist, false
|
||||||
|
// otherwise.
|
||||||
|
func (o *OIDC) IsAdmin(email string) bool {
|
||||||
|
for _, e := range o.Admins {
|
||||||
|
if e == email {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates and initializes the OIDC provider.
|
||||||
|
func (o *OIDC) Validate() error {
|
||||||
|
switch {
|
||||||
|
case o.Name == "":
|
||||||
|
return errors.New("name cannot be empty")
|
||||||
|
case o.ClientID == "":
|
||||||
|
return errors.New("clientID cannot be empty")
|
||||||
|
case o.ConfigurationEndpoint == "":
|
||||||
|
return errors.New("configurationEndpoint cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode openid-configuration endpoint
|
||||||
|
var conf openIDConfiguration
|
||||||
|
if err := getAndDecode(o.ConfigurationEndpoint, &conf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if conf.JWKSetURI == "" {
|
||||||
|
return errors.Errorf("error parsing %s: jwks_uri cannot be empty", o.ConfigurationEndpoint)
|
||||||
|
}
|
||||||
|
// Get JWK key set
|
||||||
|
keyStore, err := newKeyStore(conf.JWKSetURI)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
o.configuration = conf
|
||||||
|
o.keyStore = keyStore
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidatePayload validates the given token payload.
|
||||||
|
//
|
||||||
|
// TODO(mariano): avoid reply attacks validating nonce.
|
||||||
|
func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
||||||
|
// According to "rfc7519 JSON Web Token" acceptable skew should be no more
|
||||||
|
// than a few minutes.
|
||||||
|
if err := p.ValidateWithLeeway(jose.Expected{
|
||||||
|
Issuer: o.configuration.Issuer,
|
||||||
|
Audience: jose.Audience{o.ClientID},
|
||||||
|
}, time.Minute); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to validate payload")
|
||||||
|
}
|
||||||
|
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
|
||||||
|
return errors.New("failed to validate payload: invalid azp")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize validates the given token.
|
||||||
|
func (o *OIDC) Authorize(token string) ([]SignOption, error) {
|
||||||
|
jwt, err := jose.ParseSigned(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error parsing token")
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims openIDPayload
|
||||||
|
// Parse claims to get the kid
|
||||||
|
if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error parsing claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
kid := jwt.Headers[0].KeyID
|
||||||
|
keys := o.keyStore.Get(kid)
|
||||||
|
for _, key := range keys {
|
||||||
|
if err := jwt.Claims(key, &claims); err == nil {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, errors.New("cannot validate token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := o.ValidatePayload(claims); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.IsAdmin(claims.Email) {
|
||||||
|
return []SignOption{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return []SignOption{
|
||||||
|
emailOnlyIdentity(claims.Email),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAndDecode(uri string, v interface{}) error {
|
||||||
|
resp, err := http.Get(uri)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrapf(err, "failed to connect to %s", uri)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
|
||||||
|
return errors.Wrapf(err, "error reading %s", uri)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
55
authority/provisioner/sign_options.go
Normal file
55
authority/provisioner/sign_options.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SignOption is the interface used to collect all extra options used in the
|
||||||
|
// Sign method.
|
||||||
|
type SignOption interface{}
|
||||||
|
|
||||||
|
// CertificateValidator is the interface used to validate a X.509 certificate.
|
||||||
|
type CertificateValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(crt *x509.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// CertificateRequestValidator is the interface used to validate a X.509
|
||||||
|
// certificate request.
|
||||||
|
type CertificateRequestValidator interface {
|
||||||
|
SignOption
|
||||||
|
Valid(req *x509.CertificateRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProfileWithOption is the interface used to add custom options to the profile
|
||||||
|
// constructor. The options are used to modify the final certificate.
|
||||||
|
type ProfileWithOption interface {
|
||||||
|
SignOption
|
||||||
|
Option() x509util.WithOption
|
||||||
|
}
|
||||||
|
|
||||||
|
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
|
||||||
|
// SAN provided is the given email address.
|
||||||
|
type emailOnlyIdentity string
|
||||||
|
|
||||||
|
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
|
||||||
|
switch {
|
||||||
|
case len(req.DNSNames) > 0:
|
||||||
|
return errors.New("certificate request cannot contain DNS names")
|
||||||
|
case len(req.IPAddresses) > 0:
|
||||||
|
return errors.New("certificate request cannot contain IP addresses")
|
||||||
|
case len(req.URIs) > 0:
|
||||||
|
return errors.New("certificate request cannot contain URIs")
|
||||||
|
case len(req.EmailAddresses) == 0:
|
||||||
|
return errors.New("certificate request does not contain any email address")
|
||||||
|
case len(req.EmailAddresses) > 1:
|
||||||
|
return errors.New("certificate request does not contain too many email addresses")
|
||||||
|
case req.EmailAddresses[0] != string(e):
|
||||||
|
return errors.New("certificate request does not contain the valid email address")
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue