From 89eeada2a227f16735bb43d8de7656f66e9811d1 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 8 May 2019 15:39:50 -0700 Subject: [PATCH] Add support for loading azure tokens by tenant id. --- authority/provisioner/aws_test.go | 1 + authority/provisioner/collection.go | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 429e583e..78329838 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -49,6 +49,7 @@ func TestAWS_GetTokenID(t *testing.T) { t1, err := p1.GetIdentityToken() assert.FatalError(t, err) + t.Error(t1) _, claims, err := parseAWSToken(t1) assert.FatalError(t, err) sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID))) diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index e34a2fcf..c3c6518c 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -33,6 +33,14 @@ func (p provisionerSlice) Len() int { return len(p) } func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +// loadByTokenPayload is a payload used to extract the id used to load the +// provisioner. +type loadByTokenPayload struct { + jose.Claims + AuthorizedParty string `json:"azp"` // OIDC client id + TenantID string `json:"tid"` // Microsoft Azure tenant id +} + // Collection is a memory map of provisioners. type Collection struct { byID *sync.Map @@ -65,8 +73,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID) } - // The ID will be just the clientID stored in azp or aud. - var payload openIDPayload + // The ID will be just the clientID stored in azp, aud or tid. + var payload loadByTokenPayload if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil { return nil, false } @@ -80,6 +88,12 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) return p, ok } } + // Try with tid (Azure) + if payload.TenantID != "" { + if p, ok := c.Load(payload.TenantID); ok { + return p, ok + } + } // Fallback to aud (GCP) return c.Load(payload.Audience[0]) }