Use JWKSet to get the GCP keys.
This commit is contained in:
parent
f794dbeb93
commit
b4729cd670
2 changed files with 19 additions and 112 deletions
|
@ -47,7 +47,7 @@ type GCP struct {
|
||||||
ServiceAccounts []string `json:"serviceAccounts"`
|
ServiceAccounts []string `json:"serviceAccounts"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
claimer *Claimer
|
claimer *Claimer
|
||||||
certStore *keyStore
|
keyStore *keyStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier. The name should uniquely
|
// GetID returns the provisioner unique identifier. The name should uniquely
|
||||||
|
@ -103,8 +103,8 @@ func (p *GCP) Init(config Config) error {
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Initialize certificate store
|
// Initialize key store
|
||||||
p.certStore, err = newCertificateStore("https://www.googleapis.com/oauth2/v1/certs")
|
p.keyStore, err = newKeyStore("https://www.googleapis.com/oauth2/v3/certs")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -185,15 +185,19 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
||||||
if len(jwt.Headers) == 0 {
|
if len(jwt.Headers) == 0 {
|
||||||
return nil, errors.New("error parsing token: header is missing")
|
return nil, errors.New("error parsing token: header is missing")
|
||||||
}
|
}
|
||||||
kid := jwt.Headers[0].KeyID
|
|
||||||
cert := p.certStore.GetCertificate(kid)
|
|
||||||
if cert == nil {
|
|
||||||
return nil, errors.Errorf("failed to validate payload: cannot find certificate for kid %s", kid)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
var found bool
|
||||||
var claims gcpPayload
|
var claims gcpPayload
|
||||||
if err = jwt.Claims(cert.PublicKey, &claims); err != nil {
|
kid := jwt.Headers[0].KeyID
|
||||||
return nil, errors.Wrap(err, "error parsing claims")
|
keys := p.keyStore.Get(kid)
|
||||||
|
for _, key := range keys {
|
||||||
|
if err := jwt.Claims(key, &claims); err == nil {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, errors.Errorf("failed to validate payload: cannot find certificate for kid %s", kid)
|
||||||
}
|
}
|
||||||
|
|
||||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -22,32 +20,13 @@ const (
|
||||||
|
|
||||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
||||||
|
|
||||||
type oauth2Certificate struct {
|
|
||||||
ID string
|
|
||||||
Certificate *x509.Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
type oauth2CertificateSet struct {
|
|
||||||
Certificates []oauth2Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s oauth2CertificateSet) Get(id string) *x509.Certificate {
|
|
||||||
for _, c := range s.Certificates {
|
|
||||||
if c.ID == id {
|
|
||||||
return c.Certificate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type keyStore struct {
|
type keyStore struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
uri string
|
uri string
|
||||||
keySet jose.JSONWebKeySet
|
keySet jose.JSONWebKeySet
|
||||||
certSet oauth2CertificateSet
|
timer *time.Timer
|
||||||
timer *time.Timer
|
expiry time.Time
|
||||||
expiry time.Time
|
jitter time.Duration
|
||||||
jitter time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newKeyStore(uri string) (*keyStore, error) {
|
func newKeyStore(uri string) (*keyStore, error) {
|
||||||
|
@ -66,22 +45,6 @@ func newKeyStore(uri string) (*keyStore, error) {
|
||||||
return ks, nil
|
return ks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCertificateStore(uri string) (*keyStore, error) {
|
|
||||||
certs, age, err := getOauth2Certificates(uri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ks := &keyStore{
|
|
||||||
uri: uri,
|
|
||||||
certSet: certs,
|
|
||||||
expiry: getExpirationTime(age),
|
|
||||||
jitter: getCacheJitter(age),
|
|
||||||
}
|
|
||||||
next := ks.nextReloadDuration(age)
|
|
||||||
ks.timer = time.AfterFunc(next, ks.reloadCertificates)
|
|
||||||
return ks, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ks *keyStore) Close() {
|
func (ks *keyStore) Close() {
|
||||||
ks.timer.Stop()
|
ks.timer.Stop()
|
||||||
}
|
}
|
||||||
|
@ -99,19 +62,6 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ks *keyStore) GetCertificate(kid string) (cert *x509.Certificate) {
|
|
||||||
ks.RLock()
|
|
||||||
// Force reload if expiration has passed
|
|
||||||
if time.Now().After(ks.expiry) {
|
|
||||||
ks.RUnlock()
|
|
||||||
ks.reloadCertificates()
|
|
||||||
ks.RLock()
|
|
||||||
}
|
|
||||||
cert = ks.certSet.Get(kid)
|
|
||||||
ks.RUnlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ks *keyStore) reload() {
|
func (ks *keyStore) reload() {
|
||||||
var next time.Duration
|
var next time.Duration
|
||||||
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
||||||
|
@ -131,25 +81,6 @@ func (ks *keyStore) reload() {
|
||||||
ks.Unlock()
|
ks.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ks *keyStore) reloadCertificates() {
|
|
||||||
var next time.Duration
|
|
||||||
certs, age, err := getOauth2Certificates(ks.uri)
|
|
||||||
if err != nil {
|
|
||||||
next = ks.nextReloadDuration(ks.jitter / 2)
|
|
||||||
} else {
|
|
||||||
ks.Lock()
|
|
||||||
ks.certSet = certs
|
|
||||||
ks.expiry = getExpirationTime(age)
|
|
||||||
ks.jitter = getCacheJitter(age)
|
|
||||||
next = ks.nextReloadDuration(age)
|
|
||||||
ks.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
ks.Lock()
|
|
||||||
ks.timer.Reset(next)
|
|
||||||
ks.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||||
n := rand.Int63n(int64(ks.jitter))
|
n := rand.Int63n(int64(ks.jitter))
|
||||||
age -= time.Duration(n)
|
age -= time.Duration(n)
|
||||||
|
@ -172,34 +103,6 @@ func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||||
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
|
return keys, getCacheAge(resp.Header.Get("cache-control")), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOauth2Certificates(uri string) (oauth2CertificateSet, time.Duration, error) {
|
|
||||||
var certs oauth2CertificateSet
|
|
||||||
resp, err := http.Get(uri)
|
|
||||||
if err != nil {
|
|
||||||
return certs, 0, errors.Wrapf(err, "failed to connect to %s", uri)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
m := make(map[string]string)
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
|
||||||
return certs, 0, errors.Wrapf(err, "error reading %s", uri)
|
|
||||||
}
|
|
||||||
for k, v := range m {
|
|
||||||
block, _ := pem.Decode([]byte(v))
|
|
||||||
if block == nil || block.Type != "CERTIFICATE" {
|
|
||||||
return certs, 0, errors.Wrapf(err, "error parsing certificate %s from %s", k, uri)
|
|
||||||
}
|
|
||||||
cert, err := x509.ParseCertificate(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
return certs, 0, errors.Wrapf(err, "error parsing certificate %s from %s", k, uri)
|
|
||||||
}
|
|
||||||
certs.Certificates = append(certs.Certificates, oauth2Certificate{
|
|
||||||
ID: k,
|
|
||||||
Certificate: cert,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return certs, getCacheAge(resp.Header.Get("cache-control")), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCacheAge(cacheControl string) time.Duration {
|
func getCacheAge(cacheControl string) time.Duration {
|
||||||
age := defaultCacheAge
|
age := defaultCacheAge
|
||||||
if len(cacheControl) > 0 {
|
if len(cacheControl) > 0 {
|
||||||
|
|
Loading…
Add table
Reference in a new issue