forked from TrueCloudLab/rclone
08b9ede217
Fixes #3213
137 lines
3.8 KiB
Go
137 lines
3.8 KiB
Go
// +build !plan9,!solaris,!js,go1.13
|
|
|
|
package azureblob
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
|
|
"github.com/Azure/go-autorest/autorest/adal"
|
|
"github.com/pkg/errors"
|
|
"github.com/rclone/rclone/fs"
|
|
"github.com/rclone/rclone/fs/fshttp"
|
|
)
|
|
|
|
// azureResource is sent as the "resource" query parameter when requesting a
// token from IMDS.
const azureResource = "https://storage.azure.com"

// imdsAPIVersion is sent as the "api-version" query parameter on IMDS token
// requests.
const imdsAPIVersion = "2018-02-01"

// msiEndpointDefault is the IMDS OAuth2 token endpoint that token requests
// are sent to.
const msiEndpointDefault = "http://169.254.169.254/metadata/identity/oauth2/token"
|
|
|
|
type (
	// testPortKey is the context key type under which tests store the port
	// their local server is bound to. A dedicated named type keeps the key
	// from colliding with context values set by other packages.
	testPortKey string

	// msiIdentifierType selects which query parameter carries the value of a
	// user-assigned managed identity.
	msiIdentifierType int

	// userMSI names a user-assigned managed identity. Type determines how
	// Value is interpreted (client ID, object ID or resource ID).
	userMSI struct {
		Type  msiIdentifierType
		Value string
	}
)

// Supported ways of identifying a user-assigned managed identity.
const (
	msiClientID   msiIdentifierType = iota // Value is sent as client_id
	msiObjectID                            // Value is sent as object_id
	msiResourceID                          // Value is sent as mi_res_id
)
|
|
|
|
// httpError reports an IMDS reply whose status code was not one of the
// accepted success codes (200, 201 or 202). Response is the reply that was
// received. Its body may already have been consumed and closed by the time
// the error reaches the caller.
type httpError struct {
	Response *http.Response
}

// Error formats the status code and status text of the failed response.
// The zero value (nil Response) is handled instead of panicking.
func (e httpError) Error() string {
	if e.Response == nil {
		return "HTTP error (no response)"
	}
	return fmt.Sprintf("HTTP error %v (%v)", e.Response.StatusCode, e.Response.Status)
}
|
|
|
|
// GetMSIToken attempts to obtain an MSI token from the Azure Instance
|
|
// Metadata Service.
|
|
func GetMSIToken(ctx context.Context, identity *userMSI) (adal.Token, error) {
|
|
// Attempt to get an MSI token; silently continue if unsuccessful.
|
|
// This code has been lovingly stolen from azcopy's OAuthTokenManager.
|
|
result := adal.Token{}
|
|
req, err := http.NewRequestWithContext(ctx, "GET", msiEndpointDefault, nil)
|
|
if err != nil {
|
|
fs.Debugf(nil, "Failed to create request: %v", err)
|
|
return result, err
|
|
}
|
|
params := req.URL.Query()
|
|
params.Set("resource", azureResource)
|
|
params.Set("api-version", imdsAPIVersion)
|
|
|
|
// Specify user-assigned identity if requested.
|
|
if identity != nil {
|
|
switch identity.Type {
|
|
case msiClientID:
|
|
params.Set("client_id", identity.Value)
|
|
case msiObjectID:
|
|
params.Set("object_id", identity.Value)
|
|
case msiResourceID:
|
|
params.Set("mi_res_id", identity.Value)
|
|
default:
|
|
// If this happens, the calling function and this one don't agree on
|
|
// what valid ID types exist.
|
|
return result, fmt.Errorf("unknown MSI identity type specified")
|
|
}
|
|
}
|
|
req.URL.RawQuery = params.Encode()
|
|
|
|
// The Metadata header is required by all calls to IMDS.
|
|
req.Header.Set("Metadata", "true")
|
|
|
|
// If this function is run in a test, query the test server instead of IMDS.
|
|
testPort, isTest := ctx.Value(testPortKey("testPort")).(int)
|
|
if isTest {
|
|
req.URL.Host = fmt.Sprintf("localhost:%d", testPort)
|
|
req.Host = req.URL.Host
|
|
}
|
|
|
|
// Send request
|
|
httpClient := fshttp.NewClient(ctx)
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return result, errors.Wrap(err, "MSI is not enabled on this VM")
|
|
}
|
|
defer func() { // resp and Body should not be nil
|
|
_, err = io.Copy(ioutil.Discard, resp.Body)
|
|
if err != nil {
|
|
fs.Debugf(nil, "Unable to drain IMDS response: %v", err)
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
fs.Debugf(nil, "Unable to close IMDS response: %v", err)
|
|
}
|
|
}()
|
|
// Check if the status code indicates success
|
|
// The request returns 200 currently, add 201 and 202 as well for possible extension.
|
|
switch resp.StatusCode {
|
|
case 200, 201, 202:
|
|
break
|
|
default:
|
|
body, _ := ioutil.ReadAll(resp.Body)
|
|
fs.Errorf(nil, "Couldn't obtain OAuth token from IMDS; server returned status code %d and body: %v", resp.StatusCode, string(body))
|
|
return result, httpError{Response: resp}
|
|
}
|
|
|
|
b, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return result, errors.Wrap(err, "Couldn't read IMDS response")
|
|
}
|
|
// Remove BOM, if any. azcopy does this so I'm following along.
|
|
b = bytes.TrimPrefix(b, []byte("\xef\xbb\xbf"))
|
|
|
|
// This would be a good place to persist the token if a large number of rclone
|
|
// invocations are being made in a short amount of time. If the token is
|
|
// persisted, the azureblob code will need to check for expiry before every
|
|
// storage API call.
|
|
err = json.Unmarshal(b, &result)
|
|
if err != nil {
|
|
return result, errors.Wrap(err, "Couldn't unmarshal IMDS response")
|
|
}
|
|
|
|
return result, nil
|
|
}
|