Added support for specifying IMDS version preference

This commit is contained in:
Josh Hogle 2020-05-20 13:15:51 -04:00
parent 8c6a46887b
commit 18ac5c07e2
2 changed files with 69 additions and 23 deletions

View file

@ -147,6 +147,7 @@ type AWS struct {
Accounts []string `json:"accounts"` Accounts []string `json:"accounts"`
DisableCustomSANs bool `json:"disableCustomSANs"` DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
IMDSVersions []string `json:"imdsVersions"`
InstanceAge Duration `json:"instanceAge,omitempty"` InstanceAge Duration `json:"instanceAge,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
claimer *Claimer claimer *Claimer
@ -200,14 +201,14 @@ func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) {
var idoc awsInstanceIdentityDocument var idoc awsInstanceIdentityDocument
doc, err := p.readURL(p.config.identityURL) doc, err := p.readURL(p.config.identityURL)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error retrieving identity document, are you in an AWS VM with IMDSv2 enabled?") return "", errors.Wrap(err, "error retrieving identity document, are you in an AWS VM using the proper IMDS version?")
} }
if err := json.Unmarshal(doc, &idoc); err != nil { if err := json.Unmarshal(doc, &idoc); err != nil {
return "", errors.Wrap(err, "error unmarshaling identity document") return "", errors.Wrap(err, "error unmarshaling identity document")
} }
sig, err := p.readURL(p.config.signatureURL) sig, err := p.readURL(p.config.signatureURL)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error retrieving identity document signature, are you in an AWS VM with IMDSv2 enabled?") return "", errors.Wrap(err, "error retrieving identity document signature, are you in an AWS VM using the proper IMDS version?")
} }
signature, err := base64.StdEncoding.DecodeString(string(sig)) signature, err := base64.StdEncoding.DecodeString(string(sig))
if err != nil { if err != nil {
@ -349,43 +350,87 @@ func (p *AWS) checkSignature(signed, signature []byte) error {
// using pkg/errors to avoid verbose errors, the caller should use it and write // using pkg/errors to avoid verbose errors, the caller should use it and write
// the appropriate error. // the appropriate error.
func (p *AWS) readURL(url string) ([]byte, error) { func (p *AWS) readURL(url string) ([]byte, error) {
client := &http.Client{} var resp *http.Response
var err error
// get authorization token for _, v := range p.IMDSVersions {
switch v {
case "v1":
resp, err = p.readURLv1(url)
if err == nil && resp.StatusCode < 400 {
return p.readResponseBody(resp)
}
case "v2":
resp, err = p.readURLv2(url)
if err == nil && resp.StatusCode < 400 {
return p.readResponseBody(resp)
}
default:
return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v)
}
}
// all versions have been exhausted and we haven't returned successfully yet so pass
// the error on to the caller
if err != nil {
return nil, err
}
return nil, fmt.Errorf("Request for metadata returned non-successful status code %d",
resp.StatusCode)
}
func (p *AWS) readURLv1(url string) (*http.Response, error) {
client := http.Client{}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
func (p *AWS) readURLv2(url string) (*http.Response, error) {
client := http.Client{}
// first get the token
req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, nil) req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set(awsMetadataTokenTTLHeader, p.config.tokenTTL) req.Header.Set(awsMetadataTokenTTLHeader, p.config.tokenTTL)
r, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer r.Body.Close() defer resp.Body.Close()
if r.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP request returned non-successful status code %d", r.StatusCode) return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
} }
b, err := ioutil.ReadAll(r.Body) token, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
token := string(b)
// now get the data // now make the request
req, err = http.NewRequest(http.MethodGet, url, nil) req, err = http.NewRequest(http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set(awsMetadataTokenHeader, token) req.Header.Set(awsMetadataTokenHeader, string(token))
r, err = client.Do(req) resp, err = client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer r.Body.Close() return resp, nil
if r.StatusCode >= 400 { }
return nil, fmt.Errorf("HTTP request returned non-successful status code %d", r.StatusCode)
} func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) {
b, err = ioutil.ReadAll(r.Body) defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -408,11 +408,12 @@ func generateAWS() (*AWS, error) {
return nil, errors.Wrap(err, "error parsing AWS certificate") return nil, errors.Wrap(err, "error parsing AWS certificate")
} }
return &AWS{ return &AWS{
Type: "AWS", Type: "AWS",
Name: name, Name: name,
Accounts: []string{accountID}, Accounts: []string{accountID},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
claimer: claimer, IMDSVersions: []string{"v2", "v1"},
claimer: claimer,
config: &awsConfig{ config: &awsConfig{
identityURL: awsIdentityURL, identityURL: awsIdentityURL,
signatureURL: awsSignatureURL, signatureURL: awsSignatureURL,