Added status code checking

This commit is contained in:
Josh Hogle 2020-05-19 23:57:09 -04:00
parent af0f21d744
commit bbbe4738c7

View file

@ -196,14 +196,14 @@ func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) {
var idoc awsInstanceIdentityDocument var idoc awsInstanceIdentityDocument
doc, err := p.readURL(p.config.identityURL) doc, err := p.readURL(p.config.identityURL)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error retrieving identity document, are you in an AWS VM?") return "", errors.Wrap(err, "error retrieving identity document, are you in an AWS VM with IMDSv2 enabled?")
} }
if err := json.Unmarshal(doc, &idoc); err != nil { if err := json.Unmarshal(doc, &idoc); err != nil {
return "", errors.Wrap(err, "error unmarshaling identity document") return "", errors.Wrap(err, "error unmarshaling identity document")
} }
sig, err := p.readURL(p.config.signatureURL) sig, err := p.readURL(p.config.signatureURL)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error retrieving identity document signature, are you in an AWS VM?") return "", errors.Wrap(err, "error retrieving identity document signature, are you in an AWS VM with IMDSv2 enabled?")
} }
signature, err := base64.StdEncoding.DecodeString(string(sig)) signature, err := base64.StdEncoding.DecodeString(string(sig))
if err != nil { if err != nil {
@ -358,6 +358,9 @@ func (p *AWS) readURL(url string) ([]byte, error) {
return nil, err return nil, err
} }
defer r.Body.Close() defer r.Body.Close()
if r.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP request returned non-successful status code %d", r.StatusCode)
}
b, err := ioutil.ReadAll(r.Body) b, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
return nil, err return nil, err
@ -375,6 +378,9 @@ func (p *AWS) readURL(url string) ([]byte, error) {
return nil, err return nil, err
} }
defer r.Body.Close() defer r.Body.Close()
if r.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP request returned non-successful status code %d", r.StatusCode)
}
b, err = ioutil.ReadAll(r.Body) b, err = ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
return nil, err return nil, err