diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 4b161d9c..fcfbab27 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -26,7 +26,12 @@ import ( const azureOIDCBaseURL = "https://login.microsoftonline.com" //nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance. -const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F" +const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token" + +const azureIdentityTokenAPIVersion = "2018-02-01" + +// azureInstanceComputeURL is the URL to get the instance compute metadata. +const azureInstanceComputeURL = "http://169.254.169.254/metadata/instance/compute/azEnvironment" // azureDefaultAudience is the default audience used. const azureDefaultAudience = "https://management.azure.com/" @@ -35,15 +40,27 @@ const azureDefaultAudience = "https://management.azure.com/" // Using case insensitive as resourceGroups appears as resourcegroups. var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`) +// azureEnvironments is the list of all Azure environments. +var azureEnvironments = map[string]string{ + "AzurePublicCloud": "https://management.azure.com/", + "AzureCloud": "https://management.azure.com/", + "AzureUSGovernmentCloud": "https://management.usgovcloudapi.net/", + "AzureUSGovernment": "https://management.usgovcloudapi.net/", + "AzureChinaCloud": "https://management.chinacloudapi.cn/", + "AzureGermanCloud": "https://management.microsoftazure.de/", +} + type azureConfig struct { - oidcDiscoveryURL string - identityTokenURL string + oidcDiscoveryURL string + identityTokenURL string + instanceComputeURL string } func newAzureConfig(tenantID string) *azureConfig { return &azureConfig{ - oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration", - identityTokenURL: azureIdentityTokenURL, + oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration", + identityTokenURL: azureIdentityTokenURL, + instanceComputeURL: azureInstanceComputeURL, } } @@ -103,6 +120,7 @@ type Azure struct { oidcConfig openIDConfiguration keyStore *keyStore ctl *Controller + environment string } // GetID returns the provisioner unique identifier. @@ -167,11 +185,30 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) { // Initialize the config if this method is used from the cli. p.assertConfig() + // default to AzurePublicCloud to keep existing behavior + identityTokenResource := azureEnvironments["AzurePublicCloud"] + + var err error + p.environment, err = p.getAzureEnvironment() + if err != nil { + return "", errors.Wrap(err, "error getting azure environment") + } + + if resource, ok := azureEnvironments[p.environment]; ok { + identityTokenResource = resource + } + req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody) if err != nil { return "", errors.Wrap(err, "error creating request") } req.Header.Set("Metadata", "true") + + query := req.URL.Query() + query.Add("resource", identityTokenResource) + query.Add("api-version", azureIdentityTokenAPIVersion) + req.URL.RawQuery = query.Encode() + resp, err := http.DefaultClient.Do(req) if err != nil { return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?") @@ -444,3 +481,37 @@ func (p *Azure) assertConfig() { p.config = newAzureConfig(p.TenantID) } } + +// getAzureEnvironment returns the Azure environment for the current instance +func (p *Azure) getAzureEnvironment() (string, error) { + if p.environment != "" { + return p.environment, nil + } + + req, err := http.NewRequest("GET", p.config.instanceComputeURL, http.NoBody) + if err != nil { + return "", errors.Wrap(err, "error creating request") + } + req.Header.Add("Metadata", "True") + + query := req.URL.Query() + query.Add("format", "text") + query.Add("api-version", "2021-02-01") + req.URL.RawQuery = query.Encode() + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", errors.Wrap(err, "error getting azure instance environment, are you in a Azure VM?") + } + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + if err != nil { + return "", errors.Wrap(err, "error reading azure environment response") + } + if resp.StatusCode >= 400 { + return "", errors.Errorf("error getting azure environment: status=%d, response=%s", resp.StatusCode, b) + } + + return string(b), nil +} diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 84f2ebbf..51d46c5a 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -100,7 +100,14 @@ func TestAzure_GetIdentityToken(t *testing.T) { time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srvIdentity := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wantResource := r.URL.Query().Get("want_resource") + resource := r.URL.Query().Get("resource") + if wantResource == "" || resource != wantResource { + http.Error(w, fmt.Sprintf("Azure query param resource = %s, wantResource %s", resource, wantResource), http.StatusBadRequest) + return + } + switch r.URL.Path { case "/bad-request": http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -111,29 +118,58 @@ func TestAzure_GetIdentityToken(t *testing.T) { fmt.Fprintf(w, `{"access_token":"%s"}`, t1) } })) - defer srv.Close() + defer srvIdentity.Close() + + srvInstance := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/bad-request": + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + case "/AzureChinaCloud": + w.Header().Add("Content-Type", "text/plain") + w.Write([]byte("AzureChinaCloud")) + case "/AzureGermanCloud": + w.Header().Add("Content-Type", "text/plain") + w.Write([]byte("AzureGermanCloud")) + case "/AzureUSGovernmentCloud": + w.Header().Add("Content-Type", "text/plain") + w.Write([]byte("AzureUSGovernmentCloud")) + default: + w.Header().Add("Content-Type", "text/plain") + w.Write([]byte("AzurePublicCloud")) + } + })) + defer srvInstance.Close() type args struct { subject string caURL string } tests := []struct { - name string - azure *Azure - args args - identityTokenURL string - want string - wantErr bool + name string + azure *Azure + args args + identityTokenURL string + instanceComputeURL string + wantEnvironment string + want string + wantErr bool }{ - {"ok", p1, args{"subject", "caURL"}, srv.URL, t1, false}, - {"fail request", p1, args{"subject", "caURL"}, srv.URL + "/bad-request", "", true}, - {"fail unmarshal", p1, args{"subject", "caURL"}, srv.URL + "/bad-json", "", true}, - {"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", "", true}, - {"fail connect", p1, args{"subject", "caURL"}, "foobarzar", "", true}, + {"ok", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false}, + {"ok azure china", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false}, + {"ok azure germany", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureGermanCloud", t1, false}, + {"ok azure us gov", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureUSGovernmentCloud", t1, false}, + {"fail instance request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL + "/bad-request", "AzurePublicCloud", "", true}, + {"fail request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL, "AzurePublicCloud", "", true}, + {"fail unmarshal", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-json", srvInstance.URL, "AzurePublicCloud", "", true}, + {"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", srvInstance.URL, "AzurePublicCloud", "", true}, + {"fail connect", p1, args{"subject", "caURL"}, "foobarzar", srvInstance.URL, "AzurePublicCloud", "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tt.azure.config.identityTokenURL = tt.identityTokenURL + // reset environment between tests to avoid caching issues + p1.environment = "" + tt.azure.config.identityTokenURL = tt.identityTokenURL + "?want_resource=" + azureEnvironments[tt.wantEnvironment] + tt.azure.config.instanceComputeURL = tt.instanceComputeURL + "/" + tt.wantEnvironment got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL) if (err != nil) != tt.wantErr { t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index f0e6949f..55fdfe6f 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -665,6 +665,9 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) { AccessToken: tok, }) } + case "/metadata/instance/compute/azEnvironment": + w.Header().Add("Content-Type", "text/plain") + w.Write([]byte("AzurePublicCloud")) default: http.NotFound(w, r) } @@ -672,6 +675,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) { srv.Start() az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration" az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token" + az.config.instanceComputeURL = srv.URL + "/metadata/instance/compute/azEnvironment" return az, srv, nil }