Use JWKSet to get the GCP keys.

This commit is contained in:
Mariano Cano 2019-04-17 17:38:24 -07:00
parent f794dbeb93
commit b4729cd670
2 changed files with 19 additions and 112 deletions

View file

@ -47,7 +47,7 @@ type GCP struct {
ServiceAccounts []string `json:"serviceAccounts"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
certStore *keyStore
keyStore *keyStore
}
// GetID returns the provisioner unique identifier. The name should uniquely
@ -103,8 +103,8 @@ func (p *GCP) Init(config Config) error {
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize certificate store
p.certStore, err = newCertificateStore("https://www.googleapis.com/oauth2/v1/certs")
// Initialize key store
p.keyStore, err = newKeyStore("https://www.googleapis.com/oauth2/v3/certs")
if err != nil {
return err
}
@ -185,15 +185,19 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
if len(jwt.Headers) == 0 {
return nil, errors.New("error parsing token: header is missing")
}
kid := jwt.Headers[0].KeyID
cert := p.certStore.GetCertificate(kid)
if cert == nil {
return nil, errors.Errorf("failed to validate payload: cannot find certificate for kid %s", kid)
}
var found bool
var claims gcpPayload
if err = jwt.Claims(cert.PublicKey, &claims); err != nil {
return nil, errors.Wrap(err, "error parsing claims")
kid := jwt.Headers[0].KeyID
keys := p.keyStore.Get(kid)
for _, key := range keys {
if err := jwt.Claims(key, &claims); err == nil {
found = true
break
}
}
if !found {
return nil, errors.Errorf("failed to validate payload: cannot find certificate for kid %s", kid)
}
// According to "rfc7519 JSON Web Token" acceptable skew should be no

View file

@ -1,9 +1,7 @@
package provisioner
import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"math/rand"
"net/http"
"regexp"
@ -22,32 +20,13 @@ const (
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
type oauth2Certificate struct {
ID string
Certificate *x509.Certificate
}
type oauth2CertificateSet struct {
Certificates []oauth2Certificate
}
func (s oauth2CertificateSet) Get(id string) *x509.Certificate {
for _, c := range s.Certificates {
if c.ID == id {
return c.Certificate
}
}
return nil
}
type keyStore struct {
sync.RWMutex
uri string
keySet jose.JSONWebKeySet
certSet oauth2CertificateSet
timer *time.Timer
expiry time.Time
jitter time.Duration
uri string
keySet jose.JSONWebKeySet
timer *time.Timer
expiry time.Time
jitter time.Duration
}
func newKeyStore(uri string) (*keyStore, error) {
@ -66,22 +45,6 @@ func newKeyStore(uri string) (*keyStore, error) {
return ks, nil
}
func newCertificateStore(uri string) (*keyStore, error) {
certs, age, err := getOauth2Certificates(uri)
if err != nil {
return nil, err
}
ks := &keyStore{
uri: uri,
certSet: certs,
expiry: getExpirationTime(age),
jitter: getCacheJitter(age),
}
next := ks.nextReloadDuration(age)
ks.timer = time.AfterFunc(next, ks.reloadCertificates)
return ks, nil
}
func (ks *keyStore) Close() {
ks.timer.Stop()
}
@ -99,19 +62,6 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
return
}
func (ks *keyStore) GetCertificate(kid string) (cert *x509.Certificate) {
ks.RLock()
// Force reload if expiration has passed
if time.Now().After(ks.expiry) {
ks.RUnlock()
ks.reloadCertificates()
ks.RLock()
}
cert = ks.certSet.Get(kid)
ks.RUnlock()
return
}
func (ks *keyStore) reload() {
var next time.Duration
keys, age, err := getKeysFromJWKsURI(ks.uri)
@ -131,25 +81,6 @@ func (ks *keyStore) reload() {
ks.Unlock()
}
func (ks *keyStore) reloadCertificates() {
var next time.Duration
certs, age, err := getOauth2Certificates(ks.uri)
if err != nil {
next = ks.nextReloadDuration(ks.jitter / 2)
} else {
ks.Lock()
ks.certSet = certs
ks.expiry = getExpirationTime(age)
ks.jitter = getCacheJitter(age)
next = ks.nextReloadDuration(age)
ks.Unlock()
}
ks.Lock()
ks.timer.Reset(next)
ks.Unlock()
}
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
n := rand.Int63n(int64(ks.jitter))
age -= time.Duration(n)
@ -172,34 +103,6 @@ func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
}
func getOauth2Certificates(uri string) (oauth2CertificateSet, time.Duration, error) {
var certs oauth2CertificateSet
resp, err := http.Get(uri)
if err != nil {
return certs, 0, errors.Wrapf(err, "failed to connect to %s", uri)
}
defer resp.Body.Close()
m := make(map[string]string)
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return certs, 0, errors.Wrapf(err, "error reading %s", uri)
}
for k, v := range m {
block, _ := pem.Decode([]byte(v))
if block == nil || block.Type != "CERTIFICATE" {
return certs, 0, errors.Wrapf(err, "error parsing certificate %s from %s", k, uri)
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return certs, 0, errors.Wrapf(err, "error parsing certificate %s from %s", k, uri)
}
certs.Certificates = append(certs.Certificates, oauth2Certificate{
ID: k,
Certificate: cert,
})
}
return certs, getCacheAge(resp.Header.Get("cache-control")), nil
}
func getCacheAge(cacheControl string) time.Duration {
age := defaultCacheAge
if len(cacheControl) > 0 {