From a2a45f635bc0cdd53f46295dd27b84aebd362e7a Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 4 Mar 2019 17:58:20 -0800 Subject: [PATCH] Add initial implementation of an OIDC provisioner. --- authority/provisioner/claims.go | 17 +++ authority/provisioner/keystore.go | 122 +++++++++++++++++++++ authority/provisioner/oidc.go | 147 ++++++++++++++++++++++++++ authority/provisioner/sign_options.go | 55 ++++++++++ 4 files changed, 341 insertions(+) create mode 100644 authority/provisioner/claims.go create mode 100644 authority/provisioner/keystore.go create mode 100644 authority/provisioner/oidc.go create mode 100644 authority/provisioner/sign_options.go diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go new file mode 100644 index 00000000..fa9fa2fd --- /dev/null +++ b/authority/provisioner/claims.go @@ -0,0 +1,17 @@ +package provisioner + +import "time" + +// Claims so that individual provisioners can override global claims. +type Claims struct { + globalClaims *Claims + MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` + MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` + DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` + DisableRenewal *bool `json:"disableRenewal,omitempty"` +} + +// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal. +type Duration struct { + time.Duration +} diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go new file mode 100644 index 00000000..7e49b8d7 --- /dev/null +++ b/authority/provisioner/keystore.go @@ -0,0 +1,122 @@ +package provisioner + +import ( + "encoding/json" + "math/rand" + "net/http" + "regexp" + "strconv" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/cli/jose" +) + +const ( + defaultCacheAge = 12 * time.Hour + defaultCacheJitter = 1 * time.Hour +) + +var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)") + +type keyStore struct { + sync.RWMutex + uri string + keys jose.JSONWebKeySet + timer *time.Timer + expiry time.Time +} + +func newKeyStore(uri string) (*keyStore, error) { + keys, age, err := getKeysFromJWKsURI(uri) + if err != nil { + return nil, err + } + ks := &keyStore{ + uri: uri, + keys: keys, + expiry: getExpirationTime(age), + } + ks.timer = time.AfterFunc(age, ks.reload) + return ks, nil +} + +func (ks *keyStore) Close() { + ks.timer.Stop() +} + +func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) { + ks.RLock() + // Force reload if expiration has passed + if time.Now().After(ks.expiry) { + ks.RUnlock() + ks.reload() + ks.RLock() + } + keys = ks.keys.Key(kid) + ks.RUnlock() + return +} + +func (ks *keyStore) reload() { + var next time.Duration + keys, age, err := getKeysFromJWKsURI(ks.uri) + if err != nil { + next = ks.nextReloadDuration(defaultCacheJitter / 2) + } else { + ks.Lock() + ks.keys = keys + ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC() + ks.Unlock() + next = ks.nextReloadDuration(age) + } + + ks.Lock() + ks.timer.Reset(next) + ks.Unlock() +} + +func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { + n := rand.Int63n(int64(defaultCacheJitter)) + age -= time.Duration(n) + if age < 0 { + age = 0 + } + return age +} + +func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) { + var keys jose.JSONWebKeySet + resp, err := http.Get(uri) + if err != nil { + return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri) + } + defer resp.Body.Close() + if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { + return keys, 0, errors.Wrapf(err, "error reading %s", uri) + } + return keys, getCacheAge(resp.Header.Get("cache-control")), nil +} + +func getCacheAge(cacheControl string) time.Duration { + age := defaultCacheAge + if len(cacheControl) > 0 { + match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1) + if len(match) > 0 { + if len(match[0]) == 2 { + maxAge := match[0][1] + maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64) + if err != nil { + return defaultCacheAge + } + age = time.Duration(maxAgeInt) * time.Second + } + } + } + return age +} + +func getExpirationTime(age time.Duration) time.Time { + return time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC() +} diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go new file mode 100644 index 00000000..3c401441 --- /dev/null +++ b/authority/provisioner/oidc.go @@ -0,0 +1,147 @@ +package provisioner + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/cli/jose" +) + +type openIDConfiguration struct { + Issuer string `json:"issuer"` + JWKSetURI string `json:"jwks_uri"` +} + +// openIDPayload represents the fields on the id_token JWT payload. +type openIDPayload struct { + jose.Claims + AtHash string `json:"at_hash"` + AuthorizedParty string `json:"azp"` + Email string `json:"email"` + EmailVerified string `json:"email_verified"` + Hd string `json:"hd"` + Nonce string `json:"nonce"` +} + +// OIDC represents an OAuth 2.0 OpenID Connect provider. +type OIDC struct { + Type string `json:"type"` + Name string `json:"name"` + ClientID string `json:"clientID"` + ConfigurationEndpoint string `json:"configurationEndpoint"` + Claims *Claims `json:"claims,omitempty"` + Admins []string `json:"admins"` + configuration openIDConfiguration + keyStore *keyStore +} + +// IsAdmin returns true if the given email is in the Admins whitelist, false +// otherwise. +func (o *OIDC) IsAdmin(email string) bool { + for _, e := range o.Admins { + if e == email { + return true + } + } + return false +} + +// Validate validates and initializes the OIDC provider. +func (o *OIDC) Validate() error { + switch { + case o.Name == "": + return errors.New("name cannot be empty") + case o.ClientID == "": + return errors.New("clientID cannot be empty") + case o.ConfigurationEndpoint == "": + return errors.New("configurationEndpoint cannot be empty") + } + + // Decode openid-configuration endpoint + var conf openIDConfiguration + if err := getAndDecode(o.ConfigurationEndpoint, &conf); err != nil { + return err + } + if conf.JWKSetURI == "" { + return errors.Errorf("error parsing %s: jwks_uri cannot be empty", o.ConfigurationEndpoint) + } + // Get JWK key set + keyStore, err := newKeyStore(conf.JWKSetURI) + if err != nil { + return err + } + o.configuration = conf + o.keyStore = keyStore + return nil +} + +// ValidatePayload validates the given token payload. +// +// TODO(mariano): avoid reply attacks validating nonce. +func (o *OIDC) ValidatePayload(p openIDPayload) error { + // According to "rfc7519 JSON Web Token" acceptable skew should be no more + // than a few minutes. + if err := p.ValidateWithLeeway(jose.Expected{ + Issuer: o.configuration.Issuer, + Audience: jose.Audience{o.ClientID}, + }, time.Minute); err != nil { + return errors.Wrap(err, "failed to validate payload") + } + if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID { + return errors.New("failed to validate payload: invalid azp") + } + return nil +} + +// Authorize validates the given token. +func (o *OIDC) Authorize(token string) ([]SignOption, error) { + jwt, err := jose.ParseSigned(token) + if err != nil { + return nil, errors.Wrapf(err, "error parsing token") + } + + var claims openIDPayload + // Parse claims to get the kid + if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, errors.Wrap(err, "error parsing claims") + } + + found := false + kid := jwt.Headers[0].KeyID + keys := o.keyStore.Get(kid) + for _, key := range keys { + if err := jwt.Claims(key, &claims); err == nil { + found = true + break + } + } + if !found { + return nil, errors.New("cannot validate token") + } + + if err := o.ValidatePayload(claims); err != nil { + return nil, err + } + + if o.IsAdmin(claims.Email) { + return []SignOption{}, nil + } + + return []SignOption{ + emailOnlyIdentity(claims.Email), + }, nil +} + +func getAndDecode(uri string, v interface{}) error { + resp, err := http.Get(uri) + if err != nil { + return errors.Wrapf(err, "failed to connect to %s", uri) + } + defer resp.Body.Close() + if err := json.NewDecoder(resp.Body).Decode(v); err != nil { + return errors.Wrapf(err, "error reading %s", uri) + } + return nil +} diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go new file mode 100644 index 00000000..7208d751 --- /dev/null +++ b/authority/provisioner/sign_options.go @@ -0,0 +1,55 @@ +package provisioner + +import ( + "crypto/x509" + + "github.com/pkg/errors" + "github.com/smallstep/cli/crypto/x509util" +) + +// SignOption is the interface used to collect all extra options used in the +// Sign method. +type SignOption interface{} + +// CertificateValidator is the interface used to validate a X.509 certificate. +type CertificateValidator interface { + SignOption + Valid(crt *x509.Certificate) error +} + +// CertificateRequestValidator is the interface used to validate a X.509 +// certificate request. +type CertificateRequestValidator interface { + SignOption + Valid(req *x509.CertificateRequest) +} + +// ProfileWithOption is the interface used to add custom options to the profile +// constructor. The options are used to modify the final certificate. +type ProfileWithOption interface { + SignOption + Option() x509util.WithOption +} + +// emailOnlyIdentity is a CertificateRequestValidator that checks that the only +// SAN provided is the given email address. +type emailOnlyIdentity string + +func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error { + switch { + case len(req.DNSNames) > 0: + return errors.New("certificate request cannot contain DNS names") + case len(req.IPAddresses) > 0: + return errors.New("certificate request cannot contain IP addresses") + case len(req.URIs) > 0: + return errors.New("certificate request cannot contain URIs") + case len(req.EmailAddresses) == 0: + return errors.New("certificate request does not contain any email address") + case len(req.EmailAddresses) > 1: + return errors.New("certificate request does not contain too many email addresses") + case req.EmailAddresses[0] != string(e): + return errors.New("certificate request does not contain the valid email address") + default: + return nil + } +}