forked from TrueCloudLab/rclone
parent
6342499c47
commit
08b9ede217
5 changed files with 403 additions and 4 deletions
|
@ -102,6 +102,31 @@ for more details.
|
||||||
}, {
|
}, {
|
||||||
Name: "sas_url",
|
Name: "sas_url",
|
||||||
Help: "SAS URL for container level access only\n(leave blank if using account/key or Emulator)",
|
Help: "SAS URL for container level access only\n(leave blank if using account/key or Emulator)",
|
||||||
|
}, {
|
||||||
|
Name: "use_msi",
|
||||||
|
Help: `Use a managed service identity to authenticate (only works in Azure)
|
||||||
|
|
||||||
|
When true, use a [managed service identity](https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/)
|
||||||
|
to authenticate to Azure Storage instead of a SAS token or account key.
|
||||||
|
|
||||||
|
If the VM(SS) on which this program is running has a system-assigned identity, it will
|
||||||
|
be used by default. If the resource has no system-assigned but exactly one user-assigned identity,
|
||||||
|
the user-assigned identity will be used by default. If the resource has multiple user-assigned
|
||||||
|
identities, the identity to use must be explicitly specified using exactly one of the msi_object_id,
|
||||||
|
msi_client_id, or msi_mi_res_id parameters.`,
|
||||||
|
Default: false,
|
||||||
|
}, {
|
||||||
|
Name: "msi_object_id",
|
||||||
|
Help: "Object ID of the user-assigned MSI to use, if any. Leave blank if msi_client_id or msi_mi_res_id specified.",
|
||||||
|
Advanced: true,
|
||||||
|
}, {
|
||||||
|
Name: "msi_client_id",
|
||||||
|
Help: "Object ID of the user-assigned MSI to use, if any. Leave blank if msi_object_id or msi_mi_res_id specified.",
|
||||||
|
Advanced: true,
|
||||||
|
}, {
|
||||||
|
Name: "msi_mi_res_id",
|
||||||
|
Help: "Azure resource ID of the user-assigned MSI to use, if any. Leave blank if msi_client_id or msi_object_id specified.",
|
||||||
|
Advanced: true,
|
||||||
}, {
|
}, {
|
||||||
Name: "use_emulator",
|
Name: "use_emulator",
|
||||||
Help: "Uses local storage emulator if provided as 'true' (leave blank if using real azure storage endpoint)",
|
Help: "Uses local storage emulator if provided as 'true' (leave blank if using real azure storage endpoint)",
|
||||||
|
@ -210,6 +235,10 @@ type Options struct {
|
||||||
Account string `config:"account"`
|
Account string `config:"account"`
|
||||||
ServicePrincipalFile string `config:"service_principal_file"`
|
ServicePrincipalFile string `config:"service_principal_file"`
|
||||||
Key string `config:"key"`
|
Key string `config:"key"`
|
||||||
|
UseMSI bool `config:"use_msi"`
|
||||||
|
MSIObjectID string `config:"msi_object_id"`
|
||||||
|
MSIClientID string `config:"msi_client_id"`
|
||||||
|
MSIResourceID string `config:"msi_mi_res_id"`
|
||||||
Endpoint string `config:"endpoint"`
|
Endpoint string `config:"endpoint"`
|
||||||
SASURL string `config:"sas_url"`
|
SASURL string `config:"sas_url"`
|
||||||
UploadCutoff fs.SizeSuffix `config:"upload_cutoff"`
|
UploadCutoff fs.SizeSuffix `config:"upload_cutoff"`
|
||||||
|
@ -240,6 +269,7 @@ type Fs struct {
|
||||||
isLimited bool // if limited to one container
|
isLimited bool // if limited to one container
|
||||||
cache *bucket.Cache // cache for container creation status
|
cache *bucket.Cache // cache for container creation status
|
||||||
pacer *fs.Pacer // To pace and retry the API calls
|
pacer *fs.Pacer // To pace and retry the API calls
|
||||||
|
imdsPacer *fs.Pacer // Same but for IMDS
|
||||||
uploadToken *pacer.TokenDispenser // control concurrency
|
uploadToken *pacer.TokenDispenser // control concurrency
|
||||||
pool *pool.Pool // memory pool
|
pool *pool.Pool // memory pool
|
||||||
}
|
}
|
||||||
|
@ -342,6 +372,8 @@ func (f *Fs) shouldRetry(err error) (bool, error) {
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if httpErr, ok := err.(httpError); ok {
|
||||||
|
return fserrors.ShouldRetryHTTP(httpErr.Response, retryErrorCodes), err
|
||||||
}
|
}
|
||||||
return fserrors.ShouldRetry(err), err
|
return fserrors.ShouldRetry(err), err
|
||||||
}
|
}
|
||||||
|
@ -502,6 +534,7 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e
|
||||||
opt: *opt,
|
opt: *opt,
|
||||||
ci: ci,
|
ci: ci,
|
||||||
pacer: fs.NewPacer(ctx, pacer.NewS3(pacer.MinSleep(minSleep), pacer.MaxSleep(maxSleep), pacer.DecayConstant(decayConstant))),
|
pacer: fs.NewPacer(ctx, pacer.NewS3(pacer.MinSleep(minSleep), pacer.MaxSleep(maxSleep), pacer.DecayConstant(decayConstant))),
|
||||||
|
imdsPacer: fs.NewPacer(ctx, pacer.NewAzureIMDS()),
|
||||||
uploadToken: pacer.NewTokenDispenser(ci.Transfers),
|
uploadToken: pacer.NewTokenDispenser(ci.Transfers),
|
||||||
client: fshttp.NewClient(ctx),
|
client: fshttp.NewClient(ctx),
|
||||||
cache: bucket.NewCache(),
|
cache: bucket.NewCache(),
|
||||||
|
@ -513,6 +546,7 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e
|
||||||
opt.MemoryPoolUseMmap,
|
opt.MemoryPoolUseMmap,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
f.imdsPacer.SetRetries(5) // per IMDS documentation
|
||||||
f.setRoot(root)
|
f.setRoot(root)
|
||||||
f.features = (&fs.Features{
|
f.features = (&fs.Features{
|
||||||
ReadMimeType: true,
|
ReadMimeType: true,
|
||||||
|
@ -539,16 +573,81 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e
|
||||||
}
|
}
|
||||||
pipeline := f.newPipeline(credential, azblob.PipelineOptions{Retry: azblob.RetryOptions{TryTimeout: maxTryTimeout}})
|
pipeline := f.newPipeline(credential, azblob.PipelineOptions{Retry: azblob.RetryOptions{TryTimeout: maxTryTimeout}})
|
||||||
serviceURL = azblob.NewServiceURL(*u, pipeline)
|
serviceURL = azblob.NewServiceURL(*u, pipeline)
|
||||||
case opt.Account != "" && opt.Key != "":
|
case opt.UseMSI:
|
||||||
credential, err := azblob.NewSharedKeyCredential(opt.Account, opt.Key)
|
var token adal.Token
|
||||||
|
var userMSI *userMSI = &userMSI{}
|
||||||
|
if len(opt.MSIClientID) > 0 || len(opt.MSIObjectID) > 0 || len(opt.MSIResourceID) > 0 {
|
||||||
|
// Specifying a user-assigned identity. Exactly one of the above IDs must be specified.
|
||||||
|
// Validate and ensure exactly one is set. (To do: better validation.)
|
||||||
|
if len(opt.MSIClientID) > 0 {
|
||||||
|
if len(opt.MSIObjectID) > 0 || len(opt.MSIResourceID) > 0 {
|
||||||
|
return nil, errors.New("more than one user-assigned identity ID is set")
|
||||||
|
}
|
||||||
|
userMSI.Type = msiClientID
|
||||||
|
userMSI.Value = opt.MSIClientID
|
||||||
|
}
|
||||||
|
if len(opt.MSIObjectID) > 0 {
|
||||||
|
if len(opt.MSIClientID) > 0 || len(opt.MSIResourceID) > 0 {
|
||||||
|
return nil, errors.New("more than one user-assigned identity ID is set")
|
||||||
|
}
|
||||||
|
userMSI.Type = msiObjectID
|
||||||
|
userMSI.Value = opt.MSIObjectID
|
||||||
|
}
|
||||||
|
if len(opt.MSIResourceID) > 0 {
|
||||||
|
if len(opt.MSIClientID) > 0 || len(opt.MSIObjectID) > 0 {
|
||||||
|
return nil, errors.New("more than one user-assigned identity ID is set")
|
||||||
|
}
|
||||||
|
userMSI.Type = msiResourceID
|
||||||
|
userMSI.Value = opt.MSIResourceID
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
userMSI = nil
|
||||||
|
}
|
||||||
|
err = f.imdsPacer.Call(func() (bool, error) {
|
||||||
|
// Retry as specified by the documentation:
|
||||||
|
// https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
|
||||||
|
token, err = GetMSIToken(ctx, userMSI)
|
||||||
|
return f.shouldRetry(err)
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "Failed to parse credentials")
|
return nil, errors.Wrapf(err, "Failed to acquire MSI token")
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err = url.Parse(fmt.Sprintf("https://%s.%s", opt.Account, opt.Endpoint))
|
u, err = url.Parse(fmt.Sprintf("https://%s.%s", opt.Account, opt.Endpoint))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to make azure storage url from account and endpoint")
|
return nil, errors.Wrap(err, "failed to make azure storage url from account and endpoint")
|
||||||
}
|
}
|
||||||
|
credential := azblob.NewTokenCredential(token.AccessToken, func(credential azblob.TokenCredential) time.Duration {
|
||||||
|
fs.Debugf(f, "Token refresher called.")
|
||||||
|
var refreshedToken adal.Token
|
||||||
|
err := f.imdsPacer.Call(func() (bool, error) {
|
||||||
|
refreshedToken, err = GetMSIToken(ctx, userMSI)
|
||||||
|
return f.shouldRetry(err)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
// Failed to refresh.
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
credential.SetToken(refreshedToken.AccessToken)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
// Refresh one minute before expiry.
|
||||||
|
refreshAt := refreshedToken.Expires().UTC().Add(-1 * time.Minute)
|
||||||
|
fs.Debugf(f, "Acquired new token that expires at %v; refreshing in %d s", refreshedToken.Expires(),
|
||||||
|
int(refreshAt.Sub(now).Seconds()))
|
||||||
|
if now.After(refreshAt) {
|
||||||
|
// Acquired a causality violation.
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return refreshAt.Sub(now)
|
||||||
|
})
|
||||||
|
pipeline := f.newPipeline(credential, azblob.PipelineOptions{Retry: azblob.RetryOptions{TryTimeout: maxTryTimeout}})
|
||||||
|
serviceURL = azblob.NewServiceURL(*u, pipeline)
|
||||||
|
case opt.Account != "" && opt.Key != "":
|
||||||
|
credential, err := azblob.NewSharedKeyCredential(opt.Account, opt.Key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "Failed to parse credentials")
|
||||||
|
}
|
||||||
pipeline := f.newPipeline(credential, azblob.PipelineOptions{Retry: azblob.RetryOptions{TryTimeout: maxTryTimeout}})
|
pipeline := f.newPipeline(credential, azblob.PipelineOptions{Retry: azblob.RetryOptions{TryTimeout: maxTryTimeout}})
|
||||||
serviceURL = azblob.NewServiceURL(*u, pipeline)
|
serviceURL = azblob.NewServiceURL(*u, pipeline)
|
||||||
case opt.SASURL != "":
|
case opt.SASURL != "":
|
||||||
|
@ -590,7 +689,7 @@ func NewFs(ctx context.Context, name, root string, m configmap.Mapper) (fs.Fs, e
|
||||||
pipe := f.newPipeline(azblob.NewTokenCredential("", tokenRefresher), options)
|
pipe := f.newPipeline(azblob.NewTokenCredential("", tokenRefresher), options)
|
||||||
serviceURL = azblob.NewServiceURL(*u, pipe)
|
serviceURL = azblob.NewServiceURL(*u, pipe)
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("Need account+key or connectionString or sasURL or servicePrincipalFile")
|
return nil, errors.New("No authentication method configured")
|
||||||
}
|
}
|
||||||
f.svcURL = &serviceURL
|
f.svcURL = &serviceURL
|
||||||
|
|
||||||
|
|
137
backend/azureblob/imds.go
Normal file
137
backend/azureblob/imds.go
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
// +build !plan9,!solaris,!js,go1.13
|
||||||
|
|
||||||
|
package azureblob
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/Azure/go-autorest/autorest/adal"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rclone/rclone/fs"
|
||||||
|
"github.com/rclone/rclone/fs/fshttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
azureResource = "https://storage.azure.com"
|
||||||
|
imdsAPIVersion = "2018-02-01"
|
||||||
|
msiEndpointDefault = "http://169.254.169.254/metadata/identity/oauth2/token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This custom type is used to add the port the test server has bound to
|
||||||
|
// to the request context.
|
||||||
|
type testPortKey string
|
||||||
|
|
||||||
|
type msiIdentifierType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
msiClientID msiIdentifierType = iota
|
||||||
|
msiObjectID
|
||||||
|
msiResourceID
|
||||||
|
)
|
||||||
|
|
||||||
|
type userMSI struct {
|
||||||
|
Type msiIdentifierType
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpError struct {
|
||||||
|
Response *http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e httpError) Error() string {
|
||||||
|
return fmt.Sprintf("HTTP error %v (%v)", e.Response.StatusCode, e.Response.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMSIToken attempts to obtain an MSI token from the Azure Instance
|
||||||
|
// Metadata Service.
|
||||||
|
func GetMSIToken(ctx context.Context, identity *userMSI) (adal.Token, error) {
|
||||||
|
// Attempt to get an MSI token; silently continue if unsuccessful.
|
||||||
|
// This code has been lovingly stolen from azcopy's OAuthTokenManager.
|
||||||
|
result := adal.Token{}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", msiEndpointDefault, nil)
|
||||||
|
if err != nil {
|
||||||
|
fs.Debugf(nil, "Failed to create request: %v", err)
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
params := req.URL.Query()
|
||||||
|
params.Set("resource", azureResource)
|
||||||
|
params.Set("api-version", imdsAPIVersion)
|
||||||
|
|
||||||
|
// Specify user-assigned identity if requested.
|
||||||
|
if identity != nil {
|
||||||
|
switch identity.Type {
|
||||||
|
case msiClientID:
|
||||||
|
params.Set("client_id", identity.Value)
|
||||||
|
case msiObjectID:
|
||||||
|
params.Set("object_id", identity.Value)
|
||||||
|
case msiResourceID:
|
||||||
|
params.Set("mi_res_id", identity.Value)
|
||||||
|
default:
|
||||||
|
// If this happens, the calling function and this one don't agree on
|
||||||
|
// what valid ID types exist.
|
||||||
|
return result, fmt.Errorf("unknown MSI identity type specified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req.URL.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
// The Metadata header is required by all calls to IMDS.
|
||||||
|
req.Header.Set("Metadata", "true")
|
||||||
|
|
||||||
|
// If this function is run in a test, query the test server instead of IMDS.
|
||||||
|
testPort, isTest := ctx.Value(testPortKey("testPort")).(int)
|
||||||
|
if isTest {
|
||||||
|
req.URL.Host = fmt.Sprintf("localhost:%d", testPort)
|
||||||
|
req.Host = req.URL.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
httpClient := fshttp.NewClient(ctx)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return result, errors.Wrap(err, "MSI is not enabled on this VM")
|
||||||
|
}
|
||||||
|
defer func() { // resp and Body should not be nil
|
||||||
|
_, err = io.Copy(ioutil.Discard, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
fs.Debugf(nil, "Unable to drain IMDS response: %v", err)
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
fs.Debugf(nil, "Unable to close IMDS response: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
// Check if the status code indicates success
|
||||||
|
// The request returns 200 currently, add 201 and 202 as well for possible extension.
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case 200, 201, 202:
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
body, _ := ioutil.ReadAll(resp.Body)
|
||||||
|
fs.Errorf(nil, "Couldn't obtain OAuth token from IMDS; server returned status code %d and body: %v", resp.StatusCode, string(body))
|
||||||
|
return result, httpError{Response: resp}
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return result, errors.Wrap(err, "Couldn't read IMDS response")
|
||||||
|
}
|
||||||
|
// Remove BOM, if any. azcopy does this so I'm following along.
|
||||||
|
b = bytes.TrimPrefix(b, []byte("\xef\xbb\xbf"))
|
||||||
|
|
||||||
|
// This would be a good place to persist the token if a large number of rclone
|
||||||
|
// invocations are being made in a short amount of time. If the token is
|
||||||
|
// persisted, the azureblob code will need to check for expiry before every
|
||||||
|
// storage API call.
|
||||||
|
err = json.Unmarshal(b, &result)
|
||||||
|
if err != nil {
|
||||||
|
return result, errors.Wrap(err, "Couldn't unmarshal IMDS response")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
117
backend/azureblob/imds_test.go
Normal file
117
backend/azureblob/imds_test.go
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
// +build !plan9,!solaris,!js,go1.13
|
||||||
|
|
||||||
|
package azureblob
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Azure/go-autorest/autorest/adal"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handler(t *testing.T, actual *map[string]string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := r.ParseForm()
|
||||||
|
require.NoError(t, err)
|
||||||
|
parameters := r.URL.Query()
|
||||||
|
(*actual)["path"] = r.URL.Path
|
||||||
|
(*actual)["Metadata"] = r.Header.Get("Metadata")
|
||||||
|
(*actual)["method"] = r.Method
|
||||||
|
for paramName := range parameters {
|
||||||
|
(*actual)[paramName] = parameters.Get(paramName)
|
||||||
|
}
|
||||||
|
// Make response.
|
||||||
|
response := adal.Token{}
|
||||||
|
responseBytes, err := json.Marshal(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = w.Write(responseBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagedIdentity(t *testing.T) {
|
||||||
|
// test user-assigned identity specifiers to use
|
||||||
|
testMSIClientID := "d859b29f-5c9c-42f8-a327-ec1bc6408d79"
|
||||||
|
testMSIObjectID := "9ffeb650-3ca0-4278-962b-5a38d520591a"
|
||||||
|
testMSIResourceID := "/subscriptions/fe714c49-b8a4-4d49-9388-96a20daa318f/resourceGroups/somerg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/someidentity"
|
||||||
|
tests := []struct {
|
||||||
|
identity *userMSI
|
||||||
|
identityParameterName string
|
||||||
|
expectedAbsent []string
|
||||||
|
}{
|
||||||
|
{&userMSI{msiClientID, testMSIClientID}, "client_id", []string{"object_id", "mi_res_id"}},
|
||||||
|
{&userMSI{msiObjectID, testMSIObjectID}, "object_id", []string{"client_id", "mi_res_id"}},
|
||||||
|
{&userMSI{msiResourceID, testMSIResourceID}, "mi_res_id", []string{"object_id", "client_id"}},
|
||||||
|
{nil, "(default)", []string{"object_id", "client_id", "mi_res_id"}},
|
||||||
|
}
|
||||||
|
alwaysExpected := map[string]string{
|
||||||
|
"path": "/metadata/identity/oauth2/token",
|
||||||
|
"resource": "https://storage.azure.com",
|
||||||
|
"Metadata": "true",
|
||||||
|
"api-version": "2018-02-01",
|
||||||
|
"method": "GET",
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
actual := make(map[string]string, 10)
|
||||||
|
testServer := httptest.NewServer(handler(t, &actual))
|
||||||
|
defer testServer.Close()
|
||||||
|
testServerPort, err := strconv.Atoi(strings.Split(testServer.URL, ":")[2])
|
||||||
|
require.NoError(t, err)
|
||||||
|
ctx := context.WithValue(context.TODO(), testPortKey("testPort"), testServerPort)
|
||||||
|
_, err = GetMSIToken(ctx, test.identity)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Validate expected query parameters present
|
||||||
|
expected := make(map[string]string)
|
||||||
|
for k, v := range alwaysExpected {
|
||||||
|
expected[k] = v
|
||||||
|
}
|
||||||
|
if test.identity != nil {
|
||||||
|
expected[test.identityParameterName] = test.identity.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range expected {
|
||||||
|
value, exists := actual[key]
|
||||||
|
if assert.Truef(t, exists, "test of %s: query parameter %s was not passed",
|
||||||
|
test.identityParameterName, key) {
|
||||||
|
assert.Equalf(t, expected[key], value,
|
||||||
|
"test of %s: parameter %s has incorrect value", test.identityParameterName, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate unexpected query parameters absent
|
||||||
|
for _, key := range test.expectedAbsent {
|
||||||
|
_, exists := actual[key]
|
||||||
|
assert.Falsef(t, exists, "query parameter %s was unexpectedly passed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func errorHandler(resultCode int) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Error(w, "Test error generated", resultCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIMDSErrors(t *testing.T) {
|
||||||
|
errorCodes := []int{404, 429, 500}
|
||||||
|
for _, code := range errorCodes {
|
||||||
|
testServer := httptest.NewServer(errorHandler(code))
|
||||||
|
defer testServer.Close()
|
||||||
|
testServerPort, err := strconv.Atoi(strings.Split(testServer.URL, ":")[2])
|
||||||
|
require.NoError(t, err)
|
||||||
|
ctx := context.WithValue(context.TODO(), testPortKey("testPort"), testServerPort)
|
||||||
|
_, err = GetMSIToken(ctx, nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
httpErr, ok := err.(httpError)
|
||||||
|
require.Truef(t, ok, "HTTP error %d did not result in an httpError object", code)
|
||||||
|
assert.Equalf(t, httpErr.Response.StatusCode, code, "desired error %d but didn't get it", code)
|
||||||
|
}
|
||||||
|
}
|
|
@ -195,6 +195,23 @@ func TestAmazonCloudDrivePacer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAzureIMDSPacer(t *testing.T) {
|
||||||
|
c := NewAzureIMDS()
|
||||||
|
for _, test := range []struct {
|
||||||
|
state State
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{State{SleepTime: 0, ConsecutiveRetries: 0}, 0},
|
||||||
|
{State{SleepTime: 0, ConsecutiveRetries: 1}, 2 * time.Second},
|
||||||
|
{State{SleepTime: 2 * time.Second, ConsecutiveRetries: 2}, 6 * time.Second},
|
||||||
|
{State{SleepTime: 6 * time.Second, ConsecutiveRetries: 3}, 14 * time.Second},
|
||||||
|
{State{SleepTime: 14 * time.Second, ConsecutiveRetries: 4}, 30 * time.Second},
|
||||||
|
} {
|
||||||
|
got := c.Calculate(test.state)
|
||||||
|
assert.Equal(t, test.want, got, "test: %+v", test)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGoogleDrivePacer(t *testing.T) {
|
func TestGoogleDrivePacer(t *testing.T) {
|
||||||
// Do lots of times because of the random number!
|
// Do lots of times because of the random number!
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
|
|
|
@ -168,6 +168,35 @@ func (c *AmazonCloudDrive) Calculate(state State) time.Duration {
|
||||||
return sleepTime
|
return sleepTime
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AzureIMDS is a pacer for the Azure instance metadata service.
|
||||||
|
type AzureIMDS struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAzureIMDS returns a new Azure IMDS calculator.
|
||||||
|
func NewAzureIMDS() *AzureIMDS {
|
||||||
|
c := &AzureIMDS{}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate takes the current Pacer state and return the wait time until the next try.
|
||||||
|
func (c *AzureIMDS) Calculate(state State) time.Duration {
|
||||||
|
var addBackoff time.Duration
|
||||||
|
|
||||||
|
if state.ConsecutiveRetries == 0 {
|
||||||
|
// Initial condition: no backoff.
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.ConsecutiveRetries > 4 {
|
||||||
|
// The number of consecutive retries shouldn't exceed five.
|
||||||
|
// In case it does for some reason, cap delay.
|
||||||
|
addBackoff = 0
|
||||||
|
} else {
|
||||||
|
addBackoff = time.Duration(2<<uint(state.ConsecutiveRetries-1)) * time.Second
|
||||||
|
}
|
||||||
|
return addBackoff + state.SleepTime
|
||||||
|
}
|
||||||
|
|
||||||
// GoogleDrive is a specialized pacer for Google Drive
|
// GoogleDrive is a specialized pacer for Google Drive
|
||||||
//
|
//
|
||||||
// It implements a truncated exponential backoff strategy with randomization.
|
// It implements a truncated exponential backoff strategy with randomization.
|
||||||
|
|
Loading…
Reference in a new issue