2019-03-05 01:58:20 +00:00
|
|
|
package provisioner
|
|
|
|
|
|
|
|
import (
|
2019-04-18 00:28:21 +00:00
|
|
|
"crypto/x509"
|
2019-03-05 01:58:20 +00:00
|
|
|
"encoding/json"
|
2019-04-18 00:28:21 +00:00
|
|
|
"encoding/pem"
|
2019-03-05 01:58:20 +00:00
|
|
|
"math/rand"
|
|
|
|
"net/http"
|
|
|
|
"regexp"
|
|
|
|
"strconv"
|
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/smallstep/cli/jose"
|
|
|
|
)
|
|
|
|
|
|
|
|
const (
	// defaultCacheAge is how long fetched keys/certificates are cached when
	// the response carries no usable Cache-Control max-age directive.
	defaultCacheAge = 12 * time.Hour
	// defaultCacheJitter is the jitter window used for cache ages above one
	// hour (see getCacheJitter).
	defaultCacheJitter = 1 * time.Hour
)

// maxAgeRegex extracts the delta-seconds value of a Cache-Control max-age
// directive. Directive names are compared case-insensitively (RFC 7234
// §5.2), and at least one digit is required so that an empty "max-age="
// does not shadow a later, valid directive.
var maxAgeRegex = regexp.MustCompile(`(?i)max-age=([0-9]+)`)
|
|
|
|
|
2019-04-18 00:28:21 +00:00
|
|
|
// oauth2Certificate pairs a parsed certificate with the identifier it was
// listed under.
type oauth2Certificate struct {
	ID          string
	Certificate *x509.Certificate
}

// oauth2CertificateSet holds the certificates decoded from an OAuth2
// certificate endpoint.
type oauth2CertificateSet struct {
	Certificates []oauth2Certificate
}

// Get returns the certificate whose ID equals id, or nil if the set has no
// such entry. When several entries share an ID, the earliest one wins.
func (s oauth2CertificateSet) Get(id string) *x509.Certificate {
	for i := range s.Certificates {
		if entry := &s.Certificates[i]; entry.ID == id {
			return entry.Certificate
		}
	}
	return nil
}
|
|
|
|
|
2019-03-05 01:58:20 +00:00
|
|
|
type keyStore struct {
|
|
|
|
sync.RWMutex
|
2019-04-18 00:28:21 +00:00
|
|
|
uri string
|
|
|
|
keySet jose.JSONWebKeySet
|
|
|
|
certSet oauth2CertificateSet
|
|
|
|
timer *time.Timer
|
|
|
|
expiry time.Time
|
|
|
|
jitter time.Duration
|
2019-03-05 01:58:20 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func newKeyStore(uri string) (*keyStore, error) {
|
|
|
|
keys, age, err := getKeysFromJWKsURI(uri)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
ks := &keyStore{
|
|
|
|
uri: uri,
|
2019-03-08 20:19:44 +00:00
|
|
|
keySet: keys,
|
2019-03-05 01:58:20 +00:00
|
|
|
expiry: getExpirationTime(age),
|
2019-03-08 23:08:18 +00:00
|
|
|
jitter: getCacheJitter(age),
|
2019-03-05 01:58:20 +00:00
|
|
|
}
|
2019-03-08 23:08:18 +00:00
|
|
|
next := ks.nextReloadDuration(age)
|
|
|
|
ks.timer = time.AfterFunc(next, ks.reload)
|
2019-03-05 01:58:20 +00:00
|
|
|
return ks, nil
|
|
|
|
}
|
|
|
|
|
2019-04-18 00:28:21 +00:00
|
|
|
func newCertificateStore(uri string) (*keyStore, error) {
|
|
|
|
certs, age, err := getOauth2Certificates(uri)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
ks := &keyStore{
|
|
|
|
uri: uri,
|
|
|
|
certSet: certs,
|
|
|
|
expiry: getExpirationTime(age),
|
|
|
|
jitter: getCacheJitter(age),
|
|
|
|
}
|
|
|
|
next := ks.nextReloadDuration(age)
|
|
|
|
ks.timer = time.AfterFunc(next, ks.reloadCertificates)
|
|
|
|
return ks, nil
|
|
|
|
}
|
|
|
|
|
2019-03-05 01:58:20 +00:00
|
|
|
func (ks *keyStore) Close() {
|
|
|
|
ks.timer.Stop()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
|
|
|
ks.RLock()
|
|
|
|
// Force reload if expiration has passed
|
|
|
|
if time.Now().After(ks.expiry) {
|
|
|
|
ks.RUnlock()
|
|
|
|
ks.reload()
|
|
|
|
ks.RLock()
|
|
|
|
}
|
2019-03-08 20:19:44 +00:00
|
|
|
keys = ks.keySet.Key(kid)
|
2019-03-05 01:58:20 +00:00
|
|
|
ks.RUnlock()
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2019-04-18 00:28:21 +00:00
|
|
|
func (ks *keyStore) GetCertificate(kid string) (cert *x509.Certificate) {
|
|
|
|
ks.RLock()
|
|
|
|
// Force reload if expiration has passed
|
|
|
|
if time.Now().After(ks.expiry) {
|
|
|
|
ks.RUnlock()
|
|
|
|
ks.reloadCertificates()
|
|
|
|
ks.RLock()
|
|
|
|
}
|
|
|
|
cert = ks.certSet.Get(kid)
|
|
|
|
ks.RUnlock()
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2019-03-05 01:58:20 +00:00
|
|
|
func (ks *keyStore) reload() {
|
|
|
|
var next time.Duration
|
|
|
|
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
|
|
|
if err != nil {
|
2019-03-08 23:08:18 +00:00
|
|
|
next = ks.nextReloadDuration(ks.jitter / 2)
|
2019-03-05 01:58:20 +00:00
|
|
|
} else {
|
|
|
|
ks.Lock()
|
2019-03-08 20:19:44 +00:00
|
|
|
ks.keySet = keys
|
2019-03-08 23:08:18 +00:00
|
|
|
ks.expiry = getExpirationTime(age)
|
|
|
|
ks.jitter = getCacheJitter(age)
|
2019-03-05 01:58:20 +00:00
|
|
|
next = ks.nextReloadDuration(age)
|
2019-03-08 23:08:18 +00:00
|
|
|
ks.Unlock()
|
2019-03-05 01:58:20 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
ks.Lock()
|
|
|
|
ks.timer.Reset(next)
|
|
|
|
ks.Unlock()
|
|
|
|
}
|
|
|
|
|
2019-04-18 00:28:21 +00:00
|
|
|
func (ks *keyStore) reloadCertificates() {
|
|
|
|
var next time.Duration
|
|
|
|
certs, age, err := getOauth2Certificates(ks.uri)
|
|
|
|
if err != nil {
|
|
|
|
next = ks.nextReloadDuration(ks.jitter / 2)
|
|
|
|
} else {
|
|
|
|
ks.Lock()
|
|
|
|
ks.certSet = certs
|
|
|
|
ks.expiry = getExpirationTime(age)
|
|
|
|
ks.jitter = getCacheJitter(age)
|
|
|
|
next = ks.nextReloadDuration(age)
|
|
|
|
ks.Unlock()
|
|
|
|
}
|
|
|
|
|
|
|
|
ks.Lock()
|
|
|
|
ks.timer.Reset(next)
|
|
|
|
ks.Unlock()
|
|
|
|
}
|
|
|
|
|
2019-03-05 01:58:20 +00:00
|
|
|
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
2019-03-08 23:08:18 +00:00
|
|
|
n := rand.Int63n(int64(ks.jitter))
|
2019-03-05 01:58:20 +00:00
|
|
|
age -= time.Duration(n)
|
|
|
|
if age < 0 {
|
|
|
|
age = 0
|
|
|
|
}
|
|
|
|
return age
|
|
|
|
}
|
|
|
|
|
|
|
|
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
|
|
|
var keys jose.JSONWebKeySet
|
|
|
|
resp, err := http.Get(uri)
|
|
|
|
if err != nil {
|
|
|
|
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
|
|
|
|
return keys, 0, errors.Wrapf(err, "error reading %s", uri)
|
|
|
|
}
|
|
|
|
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
|
|
|
|
}
|
|
|
|
|
2019-04-18 00:28:21 +00:00
|
|
|
func getOauth2Certificates(uri string) (oauth2CertificateSet, time.Duration, error) {
|
|
|
|
var certs oauth2CertificateSet
|
|
|
|
resp, err := http.Get(uri)
|
|
|
|
if err != nil {
|
|
|
|
return certs, 0, errors.Wrapf(err, "failed to connect to %s", uri)
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
m := make(map[string]string)
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
|
|
|
return certs, 0, errors.Wrapf(err, "error reading %s", uri)
|
|
|
|
}
|
|
|
|
for k, v := range m {
|
|
|
|
block, _ := pem.Decode([]byte(v))
|
|
|
|
if block == nil || block.Type != "CERTIFICATE" {
|
|
|
|
return certs, 0, errors.Wrapf(err, "error parsing certificate %s from %s", k, uri)
|
|
|
|
}
|
|
|
|
cert, err := x509.ParseCertificate(block.Bytes)
|
|
|
|
if err != nil {
|
|
|
|
return certs, 0, errors.Wrapf(err, "error parsing certificate %s from %s", k, uri)
|
|
|
|
}
|
|
|
|
certs.Certificates = append(certs.Certificates, oauth2Certificate{
|
|
|
|
ID: k,
|
|
|
|
Certificate: cert,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
return certs, getCacheAge(resp.Header.Get("cache-control")), nil
|
|
|
|
}
|
|
|
|
|
2019-03-05 01:58:20 +00:00
|
|
|
func getCacheAge(cacheControl string) time.Duration {
|
|
|
|
age := defaultCacheAge
|
|
|
|
if len(cacheControl) > 0 {
|
|
|
|
match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
|
|
|
|
if len(match) > 0 {
|
|
|
|
if len(match[0]) == 2 {
|
|
|
|
maxAge := match[0][1]
|
|
|
|
maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
|
|
|
|
if err != nil {
|
|
|
|
return defaultCacheAge
|
|
|
|
}
|
|
|
|
age = time.Duration(maxAgeInt) * time.Second
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return age
|
|
|
|
}
|
|
|
|
|
2019-03-08 23:08:18 +00:00
|
|
|
func getCacheJitter(age time.Duration) time.Duration {
|
|
|
|
switch {
|
|
|
|
case age > time.Hour:
|
|
|
|
return defaultCacheJitter
|
|
|
|
default:
|
|
|
|
return age / 3
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-03-05 01:58:20 +00:00
|
|
|
// getExpirationTime returns the instant at which a cache entry fetched now
// and valid for age expires. The current time is first truncated to whole
// seconds.
func getExpirationTime(age time.Duration) time.Time {
	fetchedAt := time.Now()
	wholeSecond := fetchedAt.Truncate(time.Second)
	return wholeSecond.Add(age)
}
|