forked from TrueCloudLab/certificates
Add identity token for all Azure cloud environments
* Azure Public Cloud (default) * Azure China Cloud * Azure US Gov Cloud * Azure German Cloud
This commit is contained in:
parent
b8ee206f71
commit
b2c2eec76b
3 changed files with 119 additions and 19 deletions
|
@ -26,7 +26,12 @@ import (
|
|||
const azureOIDCBaseURL = "https://login.microsoftonline.com"
|
||||
|
||||
//nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance.
|
||||
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
|
||||
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token"
|
||||
|
||||
const azureIdentityTokenAPIVersion = "2018-02-01"
|
||||
|
||||
// azureInstanceComputeURL is the URL to get the instance compute metadata.
|
||||
const azureInstanceComputeURL = "http://169.254.169.254/metadata/instance/compute/azEnvironment"
|
||||
|
||||
// azureDefaultAudience is the default audience used.
|
||||
const azureDefaultAudience = "https://management.azure.com/"
|
||||
|
@ -35,15 +40,25 @@ const azureDefaultAudience = "https://management.azure.com/"
|
|||
// Using case insensitive as resourceGroups appears as resourcegroups.
|
||||
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
|
||||
|
||||
// azureEnvironments is the list of all Azure environments.
|
||||
var azureEnvironments = map[string]string{
|
||||
"AzurePublicCloud": "https://management.azure.com/",
|
||||
"AzureUSGovernmentCloud": "https://management.usgovcloudapi.net/",
|
||||
"AzureChinaCloud": "https://management.chinacloudapi.cn/",
|
||||
"AzureGermanCloud": "https://management.microsoftazure.de/",
|
||||
}
|
||||
|
||||
type azureConfig struct {
|
||||
oidcDiscoveryURL string
|
||||
identityTokenURL string
|
||||
instanceComputeURL string
|
||||
}
|
||||
|
||||
func newAzureConfig(tenantID string) *azureConfig {
|
||||
return &azureConfig{
|
||||
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
|
||||
identityTokenURL: azureIdentityTokenURL,
|
||||
instanceComputeURL: azureInstanceComputeURL,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -167,11 +182,28 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
|
|||
// Initialize the config if this method is used from the cli.
|
||||
p.assertConfig()
|
||||
|
||||
// default to AzurePublicCloud to keep existing behavior
|
||||
identityTokenResource := azureEnvironments["AzurePublicCloud"]
|
||||
environment, err := p.getAzureEnvironment()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error getting azure environment")
|
||||
}
|
||||
|
||||
if resource, ok := azureEnvironments[environment]; ok {
|
||||
identityTokenResource = resource
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error creating request")
|
||||
}
|
||||
req.Header.Set("Metadata", "true")
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Add("resource", identityTokenResource)
|
||||
query.Add("api-version", azureIdentityTokenAPIVersion)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?")
|
||||
|
@ -444,3 +476,33 @@ func (p *Azure) assertConfig() {
|
|||
p.config = newAzureConfig(p.TenantID)
|
||||
}
|
||||
}
|
||||
|
||||
// getAzureEnvironment returns the Azure environment for the current instance
|
||||
func (p *Azure) getAzureEnvironment() (string, error) {
|
||||
req, err := http.NewRequest("GET", p.config.instanceComputeURL, http.NoBody)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error creating request")
|
||||
}
|
||||
req.Header.Add("Metadata", "True")
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Add("format", "text")
|
||||
query.Add("api-version", "2021-02-01")
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error getting azure instance environment, are you in a Azure VM?")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error reading azure environment response")
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return "", errors.Errorf("error getting azure environment: status=%d, response=%s", resp.StatusCode, b)
|
||||
}
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
|
|
@ -100,7 +100,14 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
|||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
srvIdentity := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
wantResource := r.URL.Query().Get("want_resource")
|
||||
resource := r.URL.Query().Get("resource")
|
||||
if wantResource == "" || resource != wantResource {
|
||||
http.Error(w, fmt.Sprintf("Azure query param resource = %s, wantResource %s", resource, wantResource), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/bad-request":
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
|
@ -111,7 +118,27 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
|||
fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
defer srvIdentity.Close()
|
||||
|
||||
srvInstance := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/bad-request":
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
case "/AzureChinaCloud":
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzureChinaCloud"))
|
||||
case "/AzureGermanCloud":
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzureGermanCloud"))
|
||||
case "/AzureUSGovernmentCloud":
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzureUSGovernmentCloud"))
|
||||
default:
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzurePublicCloud"))
|
||||
}
|
||||
}))
|
||||
defer srvInstance.Close()
|
||||
|
||||
type args struct {
|
||||
subject string
|
||||
|
@ -122,18 +149,25 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
|||
azure *Azure
|
||||
args args
|
||||
identityTokenURL string
|
||||
instanceComputeURL string
|
||||
wantEnvironment string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{"subject", "caURL"}, srv.URL, t1, false},
|
||||
{"fail request", p1, args{"subject", "caURL"}, srv.URL + "/bad-request", "", true},
|
||||
{"fail unmarshal", p1, args{"subject", "caURL"}, srv.URL + "/bad-json", "", true},
|
||||
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", "", true},
|
||||
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", "", true},
|
||||
{"ok", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false},
|
||||
{"ok azure china", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false},
|
||||
{"ok azure germany", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureGermanCloud", t1, false},
|
||||
{"ok azure us gov", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureUSGovernmentCloud", t1, false},
|
||||
{"fail instance request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL + "/bad-request", "AzurePublicCloud", "", true},
|
||||
{"fail request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||
{"fail unmarshal", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-json", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", srvInstance.URL, "AzurePublicCloud", "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.azure.config.identityTokenURL = tt.identityTokenURL
|
||||
tt.azure.config.identityTokenURL = tt.identityTokenURL + "?want_resource=" + azureEnvironments[tt.wantEnvironment]
|
||||
tt.azure.config.instanceComputeURL = tt.instanceComputeURL + "/" + tt.wantEnvironment
|
||||
got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
|
|
@ -665,6 +665,9 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
|||
AccessToken: tok,
|
||||
})
|
||||
}
|
||||
case "/metadata/instance/compute/azEnvironment":
|
||||
w.Header().Add("Content-Type", "text/plain")
|
||||
w.Write([]byte("AzurePublicCloud"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
@ -672,6 +675,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
|||
srv.Start()
|
||||
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
|
||||
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
|
||||
az.config.instanceComputeURL = srv.URL + "/metadata/instance/compute/azEnvironment"
|
||||
return az, srv, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue