diff --git a/backend/azureblob/azureblob.go b/backend/azureblob/azureblob.go index 853a8e5e4..e454626e4 100644 --- a/backend/azureblob/azureblob.go +++ b/backend/azureblob/azureblob.go @@ -102,6 +102,31 @@ for more details. }, { Name: "sas_url", Help: "SAS URL for container level access only\n(leave blank if using account/key or Emulator)", + }, { + Name: "use_msi", + Help: `Use a managed service identity to authenticate (only works in Azure) + +When true, use a [managed service identity](https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/) +to authenticate to Azure Storage instead of a SAS token or account key. + +If the VM(SS) on which this program is running has a system-assigned identity, it will +be used by default. If the resource has no system-assigned but exactly one user-assigned identity, +the user-assigned identity will be used by default. If the resource has multiple user-assigned +identities, the identity to use must be explicitly specified using exactly one of the msi_object_id, +msi_client_id, or msi_mi_res_id parameters.`, + Default: false, + }, { + Name: "msi_object_id", + Help: "Object ID of the user-assigned MSI to use, if any. Leave blank if msi_client_id or msi_mi_res_id specified.", + Advanced: true, + }, { + Name: "msi_client_id", + Help: "Object ID of the user-assigned MSI to use, if any. Leave blank if msi_object_id or msi_mi_res_id specified.", + Advanced: true, + }, { + Name: "msi_mi_res_id", + Help: "Azure resource ID of the user-assigned MSI to use, if any. Leave blank if msi_client_id or msi_object_id specified.", + Advanced: true, }, { Name: "use_emulator", Help: "Uses local storage emulator if provided as 'true' (leave blank if using real azure storage endpoint)", @@ -210,6 +235,10 @@ type Options struct { Account string `config:"account"` ServicePrincipalFile string `config:"service_principal_file"` Key string `config:"key"` + UseMSI bool `config:"use_msi"` + MSIObjectID string `config:"msi_object_id"` + MSIClientID string `config:"msi_client_id"` + MSIResourceID string `config:"msi_mi_res_id"` Endpoint string `config:"endpoint"` SASURL string `config:"sas_url"` UploadCutoff fs.SizeSuffix `config:"upload_cutoff"` @@ -240,6 +269,7 @@ type Fs struct { isLimited bool // if limited to one container cache *bucket.Cache // cache for container creation status pacer *fs.Pacer // To pace and retry the API calls + imdsPacer *fs.Pacer // Same but for IMDS uploadToken *pacer.TokenDispenser // control concurrency pool *pool.Pool // memory pool } @@ -342,6 +372,8 @@ func (f *Fs) shouldRetry(err error) (bool, error) { return true, err } } + } else if httpErr, ok := err.(httpError); ok { + return fserrors.ShouldRetryHTTP(httpErr.Response, retryErrorCodes), err } return fserrors.ShouldRetry(err), err } @@ -502,6 +534,7 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e opt: *opt, ci: ci, pacer: fs.NewPacer(ctx, pacer.NewS3(pacer.MinSleep(minSleep), pacer.MaxSleep(maxSleep), pacer.DecayConstant(decayConstant))), + imdsPacer: fs.NewPacer(ctx, pacer.NewAzureIMDS()), uploadToken: pacer.NewTokenDispenser(ci.Transfers), client: fshttp.NewClient(ctx), cache: bucket.NewCache(), @@ -513,6 +546,7 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e opt.MemoryPoolUseMmap, ), } + f.imdsPacer.SetRetries(5) // per IMDS documentation f.setRoot(root) f.features = (&fs.Features{ ReadMimeType: true, @@ -539,16 +573,81 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e } pipeline := f.newPipeline(credential, azblob.PipelineOptions{Retry: azblob.RetryOptions{TryTimeout: maxTryTimeout}}) serviceURL = azblob.NewServiceURL(*u, pipeline) - case opt.Account != "" && opt.Key != "": - credential, err := azblob.NewSharedKeyCredential(opt.Account, opt.Key) + case opt.UseMSI: + var token adal.Token + var userMSI *userMSI = &userMSI{} + if len(opt.MSIClientID) > 0 || len(opt.MSIObjectID) > 0 || len(opt.MSIResourceID) > 0 { + // Specifying a user-assigned identity. Exactly one of the above IDs must be specified. + // Validate and ensure exactly one is set. (To do: better validation.) + if len(opt.MSIClientID) > 0 { + if len(opt.MSIObjectID) > 0 || len(opt.MSIResourceID) > 0 { + return nil, errors.New("more than one user-assigned identity ID is set") + } + userMSI.Type = msiClientID + userMSI.Value = opt.MSIClientID + } + if len(opt.MSIObjectID) > 0 { + if len(opt.MSIClientID) > 0 || len(opt.MSIResourceID) > 0 { + return nil, errors.New("more than one user-assigned identity ID is set") + } + userMSI.Type = msiObjectID + userMSI.Value = opt.MSIObjectID + } + if len(opt.MSIResourceID) > 0 { + if len(opt.MSIClientID) > 0 || len(opt.MSIObjectID) > 0 { + return nil, errors.New("more than one user-assigned identity ID is set") + } + userMSI.Type = msiResourceID + userMSI.Value = opt.MSIResourceID + } + } else { + userMSI = nil + } + err = f.imdsPacer.Call(func() (bool, error) { + // Retry as specified by the documentation: + // https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance + token, err = GetMSIToken(ctx, userMSI) + return f.shouldRetry(err) + }) + if err != nil { - return nil, errors.Wrapf(err, "Failed to parse credentials") + return nil, errors.Wrapf(err, "Failed to acquire MSI token") } u, err = url.Parse(fmt.Sprintf("https://%s.%s", opt.Account, opt.Endpoint)) if err != nil { return nil, errors.Wrap(err, "failed to make azure storage url from account and endpoint") } + credential := azblob.NewTokenCredential(token.AccessToken, func(credential azblob.TokenCredential) time.Duration { + fs.Debugf(f, "Token refresher called.") + var refreshedToken adal.Token + err := f.imdsPacer.Call(func() (bool, error) { + refreshedToken, err = GetMSIToken(ctx, userMSI) + return f.shouldRetry(err) + }) + if err != nil { + // Failed to refresh. + return 0 + } + credential.SetToken(refreshedToken.AccessToken) + now := time.Now().UTC() + // Refresh one minute before expiry. + refreshAt := refreshedToken.Expires().UTC().Add(-1 * time.Minute) + fs.Debugf(f, "Acquired new token that expires at %v; refreshing in %d s", refreshedToken.Expires(), + int(refreshAt.Sub(now).Seconds())) + if now.After(refreshAt) { + // Acquired a causality violation. + return 0 + } + return refreshAt.Sub(now) + }) + pipeline := f.newPipeline(credential, azblob.PipelineOptions{Retry: azblob.RetryOptions{TryTimeout: maxTryTimeout}}) + serviceURL = azblob.NewServiceURL(*u, pipeline) + case opt.Account != "" && opt.Key != "": + credential, err := azblob.NewSharedKeyCredential(opt.Account, opt.Key) + if err != nil { + return nil, errors.Wrapf(err, "Failed to parse credentials") + } pipeline := f.newPipeline(credential, azblob.PipelineOptions{Retry: azblob.RetryOptions{TryTimeout: maxTryTimeout}}) serviceURL = azblob.NewServiceURL(*u, pipeline) case opt.SASURL != "": @@ -590,7 +689,7 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e pipe := f.newPipeline(azblob.NewTokenCredential("", tokenRefresher), options) serviceURL = azblob.NewServiceURL(*u, pipe) default: - return nil, errors.New("Need account+key or connectionString or sasURL or servicePrincipalFile") + return nil, errors.New("No authentication method configured") } f.svcURL = &serviceURL diff --git a/backend/azureblob/imds.go b/backend/azureblob/imds.go new file mode 100644 index 000000000..4af267e94 --- /dev/null +++ b/backend/azureblob/imds.go @@ -0,0 +1,137 @@ +// +build !plan9,!solaris,!js,go1.13 + +package azureblob + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/Azure/go-autorest/autorest/adal" + "github.com/pkg/errors" + "github.com/rclone/rclone/fs" + "github.com/rclone/rclone/fs/fshttp" +) + +const ( + azureResource = "https://storage.azure.com" + imdsAPIVersion = "2018-02-01" + msiEndpointDefault = "http://169.254.169.254/metadata/identity/oauth2/token" +) + +// This custom type is used to add the port the test server has bound to +// to the request context. +type testPortKey string + +type msiIdentifierType int + +const ( + msiClientID msiIdentifierType = iota + msiObjectID + msiResourceID +) + +type userMSI struct { + Type msiIdentifierType + Value string +} + +type httpError struct { + Response *http.Response +} + +func (e httpError) Error() string { + return fmt.Sprintf("HTTP error %v (%v)", e.Response.StatusCode, e.Response.Status) +} + +// GetMSIToken attempts to obtain an MSI token from the Azure Instance +// Metadata Service. +func GetMSIToken(ctx context.Context, identity *userMSI) (adal.Token, error) { + // Attempt to get an MSI token; silently continue if unsuccessful. + // This code has been lovingly stolen from azcopy's OAuthTokenManager. + result := adal.Token{} + req, err := http.NewRequestWithContext(ctx, "GET", msiEndpointDefault, nil) + if err != nil { + fs.Debugf(nil, "Failed to create request: %v", err) + return result, err + } + params := req.URL.Query() + params.Set("resource", azureResource) + params.Set("api-version", imdsAPIVersion) + + // Specify user-assigned identity if requested. + if identity != nil { + switch identity.Type { + case msiClientID: + params.Set("client_id", identity.Value) + case msiObjectID: + params.Set("object_id", identity.Value) + case msiResourceID: + params.Set("mi_res_id", identity.Value) + default: + // If this happens, the calling function and this one don't agree on + // what valid ID types exist. + return result, fmt.Errorf("unknown MSI identity type specified") + } + } + req.URL.RawQuery = params.Encode() + + // The Metadata header is required by all calls to IMDS. + req.Header.Set("Metadata", "true") + + // If this function is run in a test, query the test server instead of IMDS. + testPort, isTest := ctx.Value(testPortKey("testPort")).(int) + if isTest { + req.URL.Host = fmt.Sprintf("localhost:%d", testPort) + req.Host = req.URL.Host + } + + // Send request + httpClient := fshttp.NewClient(ctx) + resp, err := httpClient.Do(req) + if err != nil { + return result, errors.Wrap(err, "MSI is not enabled on this VM") + } + defer func() { // resp and Body should not be nil + _, err = io.Copy(ioutil.Discard, resp.Body) + if err != nil { + fs.Debugf(nil, "Unable to drain IMDS response: %v", err) + } + err = resp.Body.Close() + if err != nil { + fs.Debugf(nil, "Unable to close IMDS response: %v", err) + } + }() + // Check if the status code indicates success + // The request returns 200 currently, add 201 and 202 as well for possible extension. + switch resp.StatusCode { + case 200, 201, 202: + break + default: + body, _ := ioutil.ReadAll(resp.Body) + fs.Errorf(nil, "Couldn't obtain OAuth token from IMDS; server returned status code %d and body: %v", resp.StatusCode, string(body)) + return result, httpError{Response: resp} + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return result, errors.Wrap(err, "Couldn't read IMDS response") + } + // Remove BOM, if any. azcopy does this so I'm following along. + b = bytes.TrimPrefix(b, []byte("\xef\xbb\xbf")) + + // This would be a good place to persist the token if a large number of rclone + // invocations are being made in a short amount of time. If the token is + // persisted, the azureblob code will need to check for expiry before every + // storage API call. + err = json.Unmarshal(b, &result) + if err != nil { + return result, errors.Wrap(err, "Couldn't unmarshal IMDS response") + } + + return result, nil +} diff --git a/backend/azureblob/imds_test.go b/backend/azureblob/imds_test.go new file mode 100644 index 000000000..315df1bd1 --- /dev/null +++ b/backend/azureblob/imds_test.go @@ -0,0 +1,117 @@ +// +build !plan9,!solaris,!js,go1.13 + +package azureblob + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/Azure/go-autorest/autorest/adal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func handler(t *testing.T, actual *map[string]string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + require.NoError(t, err) + parameters := r.URL.Query() + (*actual)["path"] = r.URL.Path + (*actual)["Metadata"] = r.Header.Get("Metadata") + (*actual)["method"] = r.Method + for paramName := range parameters { + (*actual)[paramName] = parameters.Get(paramName) + } + // Make response. + response := adal.Token{} + responseBytes, err := json.Marshal(response) + require.NoError(t, err) + _, err = w.Write(responseBytes) + require.NoError(t, err) + } +} + +func TestManagedIdentity(t *testing.T) { + // test user-assigned identity specifiers to use + testMSIClientID := "d859b29f-5c9c-42f8-a327-ec1bc6408d79" + testMSIObjectID := "9ffeb650-3ca0-4278-962b-5a38d520591a" + testMSIResourceID := "/subscriptions/fe714c49-b8a4-4d49-9388-96a20daa318f/resourceGroups/somerg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/someidentity" + tests := []struct { + identity *userMSI + identityParameterName string + expectedAbsent []string + }{ + {&userMSI{msiClientID, testMSIClientID}, "client_id", []string{"object_id", "mi_res_id"}}, + {&userMSI{msiObjectID, testMSIObjectID}, "object_id", []string{"client_id", "mi_res_id"}}, + {&userMSI{msiResourceID, testMSIResourceID}, "mi_res_id", []string{"object_id", "client_id"}}, + {nil, "(default)", []string{"object_id", "client_id", "mi_res_id"}}, + } + alwaysExpected := map[string]string{ + "path": "/metadata/identity/oauth2/token", + "resource": "https://storage.azure.com", + "Metadata": "true", + "api-version": "2018-02-01", + "method": "GET", + } + for _, test := range tests { + actual := make(map[string]string, 10) + testServer := httptest.NewServer(handler(t, &actual)) + defer testServer.Close() + testServerPort, err := strconv.Atoi(strings.Split(testServer.URL, ":")[2]) + require.NoError(t, err) + ctx := context.WithValue(context.TODO(), testPortKey("testPort"), testServerPort) + _, err = GetMSIToken(ctx, test.identity) + require.NoError(t, err) + + // Validate expected query parameters present + expected := make(map[string]string) + for k, v := range alwaysExpected { + expected[k] = v + } + if test.identity != nil { + expected[test.identityParameterName] = test.identity.Value + } + + for key := range expected { + value, exists := actual[key] + if assert.Truef(t, exists, "test of %s: query parameter %s was not passed", + test.identityParameterName, key) { + assert.Equalf(t, expected[key], value, + "test of %s: parameter %s has incorrect value", test.identityParameterName, key) + } + } + + // Validate unexpected query parameters absent + for _, key := range test.expectedAbsent { + _, exists := actual[key] + assert.Falsef(t, exists, "query parameter %s was unexpectedly passed") + } + } +} + +func errorHandler(resultCode int) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Test error generated", resultCode) + } +} + +func TestIMDSErrors(t *testing.T) { + errorCodes := []int{404, 429, 500} + for _, code := range errorCodes { + testServer := httptest.NewServer(errorHandler(code)) + defer testServer.Close() + testServerPort, err := strconv.Atoi(strings.Split(testServer.URL, ":")[2]) + require.NoError(t, err) + ctx := context.WithValue(context.TODO(), testPortKey("testPort"), testServerPort) + _, err = GetMSIToken(ctx, nil) + require.Error(t, err) + httpErr, ok := err.(httpError) + require.Truef(t, ok, "HTTP error %d did not result in an httpError object", code) + assert.Equalf(t, httpErr.Response.StatusCode, code, "desired error %d but didn't get it", code) + } +} diff --git a/lib/pacer/pacer_test.go b/lib/pacer/pacer_test.go index 3c17cdc9a..6159acd4e 100644 --- a/lib/pacer/pacer_test.go +++ b/lib/pacer/pacer_test.go @@ -195,6 +195,23 @@ func TestAmazonCloudDrivePacer(t *testing.T) { } } +func TestAzureIMDSPacer(t *testing.T) { + c := NewAzureIMDS() + for _, test := range []struct { + state State + want time.Duration + }{ + {State{SleepTime: 0, ConsecutiveRetries: 0}, 0}, + {State{SleepTime: 0, ConsecutiveRetries: 1}, 2 * time.Second}, + {State{SleepTime: 2 * time.Second, ConsecutiveRetries: 2}, 6 * time.Second}, + {State{SleepTime: 6 * time.Second, ConsecutiveRetries: 3}, 14 * time.Second}, + {State{SleepTime: 14 * time.Second, ConsecutiveRetries: 4}, 30 * time.Second}, + } { + got := c.Calculate(test.state) + assert.Equal(t, test.want, got, "test: %+v", test) + } +} + func TestGoogleDrivePacer(t *testing.T) { // Do lots of times because of the random number! for _, test := range []struct { diff --git a/lib/pacer/pacers.go b/lib/pacer/pacers.go index 9c42be198..d24c10285 100644 --- a/lib/pacer/pacers.go +++ b/lib/pacer/pacers.go @@ -168,6 +168,35 @@ func (c *AmazonCloudDrive) Calculate(state State) time.Duration { return sleepTime } +// AzureIMDS is a pacer for the Azure instance metadata service. +type AzureIMDS struct { +} + +// NewAzureIMDS returns a new Azure IMDS calculator. +func NewAzureIMDS() *AzureIMDS { + c := &AzureIMDS{} + return c +} + +// Calculate takes the current Pacer state and return the wait time until the next try. +func (c *AzureIMDS) Calculate(state State) time.Duration { + var addBackoff time.Duration + + if state.ConsecutiveRetries == 0 { + // Initial condition: no backoff. + return 0 + } + + if state.ConsecutiveRetries > 4 { + // The number of consecutive retries shouldn't exceed five. + // In case it does for some reason, cap delay. + addBackoff = 0 + } else { + addBackoff = time.Duration(2<