Added support for specifying IMDS version preference

This commit is contained in:
Josh Hogle 2020-05-20 13:15:51 -04:00
parent 8c6a46887b
commit 18ac5c07e2
2 changed files with 69 additions and 23 deletions

View file

@ -147,6 +147,7 @@ type AWS struct {
Accounts []string `json:"accounts"`
DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
IMDSVersions []string `json:"imdsVersions"`
InstanceAge Duration `json:"instanceAge,omitempty"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
@ -200,14 +201,14 @@ func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) {
var idoc awsInstanceIdentityDocument
doc, err := p.readURL(p.config.identityURL)
if err != nil {
return "", errors.Wrap(err, "error retrieving identity document, are you in an AWS VM with IMDSv2 enabled?")
return "", errors.Wrap(err, "error retrieving identity document, are you in an AWS VM using the proper IMDS version?")
}
if err := json.Unmarshal(doc, &idoc); err != nil {
return "", errors.Wrap(err, "error unmarshaling identity document")
}
sig, err := p.readURL(p.config.signatureURL)
if err != nil {
return "", errors.Wrap(err, "error retrieving identity document signature, are you in an AWS VM with IMDSv2 enabled?")
return "", errors.Wrap(err, "error retrieving identity document signature, are you in an AWS VM using the proper IMDS version?")
}
signature, err := base64.StdEncoding.DecodeString(string(sig))
if err != nil {
@ -349,43 +350,87 @@ func (p *AWS) checkSignature(signed, signature []byte) error {
// using pkg/errors to avoid verbose errors, the caller should use it and write
// the appropriate error.
func (p *AWS) readURL(url string) ([]byte, error) {
client := &http.Client{}
var resp *http.Response
var err error
// get authorization token
for _, v := range p.IMDSVersions {
switch v {
case "v1":
resp, err = p.readURLv1(url)
if err == nil && resp.StatusCode < 400 {
return p.readResponseBody(resp)
}
case "v2":
resp, err = p.readURLv2(url)
if err == nil && resp.StatusCode < 400 {
return p.readResponseBody(resp)
}
default:
return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v)
}
}
// all versions have been exhausted and we haven't returned successfully yet so pass
// the error on to the caller
if err != nil {
return nil, err
}
return nil, fmt.Errorf("Request for metadata returned non-successful status code %d",
resp.StatusCode)
}
func (p *AWS) readURLv1(url string) (*http.Response, error) {
client := http.Client{}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
func (p *AWS) readURLv2(url string) (*http.Response, error) {
client := http.Client{}
// first get the token
req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, nil)
if err != nil {
return nil, err
}
req.Header.Set(awsMetadataTokenTTLHeader, p.config.tokenTTL)
r, err := client.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer r.Body.Close()
if r.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP request returned non-successful status code %d", r.StatusCode)
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
}
b, err := ioutil.ReadAll(r.Body)
token, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
token := string(b)
// now get the data
// now make the request
req, err = http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set(awsMetadataTokenHeader, token)
r, err = client.Do(req)
req.Header.Set(awsMetadataTokenHeader, string(token))
resp, err = client.Do(req)
if err != nil {
return nil, err
}
defer r.Body.Close()
if r.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP request returned non-successful status code %d", r.StatusCode)
}
b, err = ioutil.ReadAll(r.Body)
return resp, nil
}
func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) {
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}

View file

@ -408,11 +408,12 @@ func generateAWS() (*AWS, error) {
return nil, errors.Wrap(err, "error parsing AWS certificate")
}
return &AWS{
Type: "AWS",
Name: name,
Accounts: []string{accountID},
Claims: &globalProvisionerClaims,
claimer: claimer,
Type: "AWS",
Name: name,
Accounts: []string{accountID},
Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v2", "v1"},
claimer: claimer,
config: &awsConfig{
identityURL: awsIdentityURL,
signatureURL: awsSignatureURL,