Add support for loading azure tokens by tenant id.

This commit is contained in:
Mariano Cano 2019-05-08 15:39:50 -07:00
parent 803d81d332
commit 89eeada2a2
2 changed files with 17 additions and 2 deletions

View file

@ -49,6 +49,7 @@ func TestAWS_GetTokenID(t *testing.T) {
t1, err := p1.GetIdentityToken() t1, err := p1.GetIdentityToken()
assert.FatalError(t, err) assert.FatalError(t, err)
t.Error(t1)
_, claims, err := parseAWSToken(t1) _, claims, err := parseAWSToken(t1)
assert.FatalError(t, err) assert.FatalError(t, err)
sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID))) sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID)))

View file

@ -33,6 +33,14 @@ func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// loadByTokenPayload is a payload used to extract the id used to load the
// provisioner.
type loadByTokenPayload struct {
jose.Claims
AuthorizedParty string `json:"azp"` // OIDC client id
TenantID string `json:"tid"` // Microsoft Azure tenant id
}
// Collection is a memory map of provisioners. // Collection is a memory map of provisioners.
type Collection struct { type Collection struct {
byID *sync.Map byID *sync.Map
@ -65,8 +73,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID) return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
} }
// The ID will be just the clientID stored in azp or aud. // The ID will be just the clientID stored in azp, aud or tid.
var payload openIDPayload var payload loadByTokenPayload
if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil { if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
return nil, false return nil, false
} }
@ -80,6 +88,12 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
return p, ok return p, ok
} }
} }
// Try with tid (Azure)
if payload.TenantID != "" {
if p, ok := c.Load(payload.TenantID); ok {
return p, ok
}
}
// Fallback to aud (GCP) // Fallback to aud (GCP)
return c.Load(payload.Audience[0]) return c.Load(payload.Audience[0])
} }