Support Azure tokens from managed identities not associated with a VM

This commit is contained in:
vijayjt 2022-03-07 11:24:58 +00:00
parent ea454f9dfc
commit e699244291
3 changed files with 31 additions and 17 deletions

View file

@ -30,7 +30,7 @@ const azureDefaultAudience = "https://management.azure.com/"
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim. // azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
// Using case insensitive as resourceGroups appears as resourcegroups. // Using case insensitive as resourceGroups appears as resourcegroups.
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`) var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
type azureConfig struct { type azureConfig struct {
oidcDiscoveryURL string oidcDiscoveryURL string
@ -263,11 +263,19 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str
} }
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 { if len(re) != 5 {
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
} }
var subscription, group, name string
identityObjectID := claims.ObjectID identityObjectID := claims.ObjectID
subscription, group, name := re[1], re[2], re[3]
if strings.Contains(claims.XMSMirID, "virtualMachines") {
subscription, group, name = re[1], re[2], re[4]
} else {
// This is not a VM resource ID so we don't have the VM name so set that to the empty string
subscription, group, name = re[1], re[2], ""
}
return &claims, name, group, subscription, identityObjectID, nil return &claims, name, group, subscription, identityObjectID, nil
} }

View file

@ -95,7 +95,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
@ -237,7 +237,7 @@ func TestAzure_authorizeToken(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), jwk) time.Now(), jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -252,7 +252,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -267,7 +267,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
"foo", "subscriptionID", "resourceGroup", "virtualMachine", "foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -321,7 +321,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -437,28 +437,28 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience", failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), badKey) time.Now(), badKey)
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -671,7 +671,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
w.Header().Add("Cache-Control", "max-age=5") w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(az.keyStore.keySet)) writeJSON(w, getPublic(az.keyStore.keySet))
case "/metadata/identity/oauth2/token": case "/metadata/identity/oauth2/token":
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0]) tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0])
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} else { } else {
@ -1009,7 +1009,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName string, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner( sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
@ -1017,6 +1017,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
if err != nil { if err != nil {
return "", err return "", err
} }
var xmsMirID string
if resourceType == "vm" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName)
} else if resourceType == "uai" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName)
}
claims := azurePayload{ claims := azurePayload{
Claims: jose.Claims{ Claims: jose.Claims{
@ -1034,7 +1040,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
ObjectID: "the-oid", ObjectID: "the-oid",
TenantID: tenantID, TenantID: tenantID,
Version: "the-version", Version: "the-version",
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine), XMSMirID: xmsMirID,
} }
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }