Add identity token for all Azure cloud environments

* Azure Public Cloud (default)
* Azure China Cloud
* Azure US Gov Cloud
* Azure German Cloud
This commit is contained in:
Remi Vichery 2023-03-06 17:33:14 -08:00
parent b8ee206f71
commit b2c2eec76b
No known key found for this signature in database
GPG key ID: B0CE1B4CEA178D90
3 changed files with 119 additions and 19 deletions

View file

@ -26,7 +26,12 @@ import (
const azureOIDCBaseURL = "https://login.microsoftonline.com" const azureOIDCBaseURL = "https://login.microsoftonline.com"
//nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance. //nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance.
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F" const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token"
const azureIdentityTokenAPIVersion = "2018-02-01"
// azureInstanceComputeURL is the URL to get the instance compute metadata.
const azureInstanceComputeURL = "http://169.254.169.254/metadata/instance/compute/azEnvironment"
// azureDefaultAudience is the default audience used. // azureDefaultAudience is the default audience used.
const azureDefaultAudience = "https://management.azure.com/" const azureDefaultAudience = "https://management.azure.com/"
@ -35,15 +40,25 @@ const azureDefaultAudience = "https://management.azure.com/"
// Using case insensitive as resourceGroups appears as resourcegroups. // Using case insensitive as resourceGroups appears as resourcegroups.
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`) var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
// azureEnvironments is the list of all Azure environments.
var azureEnvironments = map[string]string{
"AzurePublicCloud": "https://management.azure.com/",
"AzureUSGovernmentCloud": "https://management.usgovcloudapi.net/",
"AzureChinaCloud": "https://management.chinacloudapi.cn/",
"AzureGermanCloud": "https://management.microsoftazure.de/",
}
type azureConfig struct { type azureConfig struct {
oidcDiscoveryURL string oidcDiscoveryURL string
identityTokenURL string identityTokenURL string
instanceComputeURL string
} }
func newAzureConfig(tenantID string) *azureConfig { func newAzureConfig(tenantID string) *azureConfig {
return &azureConfig{ return &azureConfig{
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration", oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
identityTokenURL: azureIdentityTokenURL, identityTokenURL: azureIdentityTokenURL,
instanceComputeURL: azureInstanceComputeURL,
} }
} }
@ -167,11 +182,28 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
// Initialize the config if this method is used from the cli. // Initialize the config if this method is used from the cli.
p.assertConfig() p.assertConfig()
// default to AzurePublicCloud to keep existing behavior
identityTokenResource := azureEnvironments["AzurePublicCloud"]
environment, err := p.getAzureEnvironment()
if err != nil {
return "", errors.Wrap(err, "error getting azure environment")
}
if resource, ok := azureEnvironments[environment]; ok {
identityTokenResource = resource
}
req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody) req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error creating request") return "", errors.Wrap(err, "error creating request")
} }
req.Header.Set("Metadata", "true") req.Header.Set("Metadata", "true")
query := req.URL.Query()
query.Add("resource", identityTokenResource)
query.Add("api-version", azureIdentityTokenAPIVersion)
req.URL.RawQuery = query.Encode()
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?") return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?")
@ -444,3 +476,33 @@ func (p *Azure) assertConfig() {
p.config = newAzureConfig(p.TenantID) p.config = newAzureConfig(p.TenantID)
} }
} }
// getAzureEnvironment returns the Azure environment for the current instance
func (p *Azure) getAzureEnvironment() (string, error) {
req, err := http.NewRequest("GET", p.config.instanceComputeURL, http.NoBody)
if err != nil {
return "", errors.Wrap(err, "error creating request")
}
req.Header.Add("Metadata", "True")
query := req.URL.Query()
query.Add("format", "text")
query.Add("api-version", "2021-02-01")
req.URL.RawQuery = query.Encode()
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", errors.Wrap(err, "error getting azure instance environment, are you in a Azure VM?")
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return "", errors.Wrap(err, "error reading azure environment response")
}
if resp.StatusCode >= 400 {
return "", errors.Errorf("error getting azure environment: status=%d, response=%s", resp.StatusCode, b)
}
return string(b), nil
}

View file

@ -100,7 +100,14 @@ func TestAzure_GetIdentityToken(t *testing.T) {
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srvIdentity := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wantResource := r.URL.Query().Get("want_resource")
resource := r.URL.Query().Get("resource")
if wantResource == "" || resource != wantResource {
http.Error(w, fmt.Sprintf("Azure query param resource = %s, wantResource %s", resource, wantResource), http.StatusBadRequest)
return
}
switch r.URL.Path { switch r.URL.Path {
case "/bad-request": case "/bad-request":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
@ -111,29 +118,56 @@ func TestAzure_GetIdentityToken(t *testing.T) {
fmt.Fprintf(w, `{"access_token":"%s"}`, t1) fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
} }
})) }))
defer srv.Close() defer srvIdentity.Close()
srvInstance := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/bad-request":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/AzureChinaCloud":
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzureChinaCloud"))
case "/AzureGermanCloud":
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzureGermanCloud"))
case "/AzureUSGovernmentCloud":
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzureUSGovernmentCloud"))
default:
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzurePublicCloud"))
}
}))
defer srvInstance.Close()
type args struct { type args struct {
subject string subject string
caURL string caURL string
} }
tests := []struct { tests := []struct {
name string name string
azure *Azure azure *Azure
args args args args
identityTokenURL string identityTokenURL string
want string instanceComputeURL string
wantErr bool wantEnvironment string
want string
wantErr bool
}{ }{
{"ok", p1, args{"subject", "caURL"}, srv.URL, t1, false}, {"ok", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false},
{"fail request", p1, args{"subject", "caURL"}, srv.URL + "/bad-request", "", true}, {"ok azure china", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false},
{"fail unmarshal", p1, args{"subject", "caURL"}, srv.URL + "/bad-json", "", true}, {"ok azure germany", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureGermanCloud", t1, false},
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", "", true}, {"ok azure us gov", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureUSGovernmentCloud", t1, false},
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", "", true}, {"fail instance request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL + "/bad-request", "AzurePublicCloud", "", true},
{"fail request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL, "AzurePublicCloud", "", true},
{"fail unmarshal", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-json", srvInstance.URL, "AzurePublicCloud", "", true},
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", srvInstance.URL, "AzurePublicCloud", "", true},
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", srvInstance.URL, "AzurePublicCloud", "", true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tt.azure.config.identityTokenURL = tt.identityTokenURL tt.azure.config.identityTokenURL = tt.identityTokenURL + "?want_resource=" + azureEnvironments[tt.wantEnvironment]
tt.azure.config.instanceComputeURL = tt.instanceComputeURL + "/" + tt.wantEnvironment
got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL) got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)

View file

@ -665,6 +665,9 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
AccessToken: tok, AccessToken: tok,
}) })
} }
case "/metadata/instance/compute/azEnvironment":
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzurePublicCloud"))
default: default:
http.NotFound(w, r) http.NotFound(w, r)
} }
@ -672,6 +675,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
srv.Start() srv.Start()
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration" az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token" az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
az.config.instanceComputeURL = srv.URL + "/metadata/instance/compute/azEnvironment"
return az, srv, nil return az, srv, nil
} }