[release/2.8] vendor: github.com/aws/aws-sdk-go v1.43.16

Signed-off-by: Trevor Wood <Trevor.G.Wood@gmail.com>
This commit is contained in:
Trevor Wood 2022-02-27 13:16:02 -05:00
parent b5ca020cfb
commit 205fdeac51
No known key found for this signature in database
GPG key ID: EF40387E30E40DF1
214 changed files with 66647 additions and 9825 deletions

View file

@ -1,7 +1,7 @@
github.com/Azure/azure-sdk-for-go 4650843026a7fdec254a8d9cf893693a254edd0b github.com/Azure/azure-sdk-for-go 4650843026a7fdec254a8d9cf893693a254edd0b
github.com/Azure/go-autorest eaa7994b2278094c904d31993d26f56324db3052 github.com/Azure/go-autorest eaa7994b2278094c904d31993d26f56324db3052
github.com/sirupsen/logrus 3d4380f53a34dcdc95f0c1db702615992b38d9a4 github.com/sirupsen/logrus 3d4380f53a34dcdc95f0c1db702615992b38d9a4
github.com/aws/aws-sdk-go f831d5a0822a1ad72420ab18c6269bca1ddaf490 github.com/aws/aws-sdk-go e35597a5d580c964527fa56e9bd10aa747759578
github.com/bshuster-repo/logrus-logstash-hook d2c0ecc1836d91814e15e23bb5dc309c3ef51f4a github.com/bshuster-repo/logrus-logstash-hook d2c0ecc1836d91814e15e23bb5dc309c3ef51f4a
github.com/beorn7/perks 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 github.com/beorn7/perks 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
github.com/bugsnag/bugsnag-go b1d153021fcd90ca3f080db36bec96dc690fb274 github.com/bugsnag/bugsnag-go b1d153021fcd90ca3f080db36bec96dc690fb274
@ -12,7 +12,6 @@ github.com/dgrijalva/jwt-go 4bbdd8ac624fc7a9ef7aec841c43d99b5fe65a29 https://git
github.com/docker/go-metrics 399ea8c73916000c64c2c76e8da00ca82f8387ab github.com/docker/go-metrics 399ea8c73916000c64c2c76e8da00ca82f8387ab
github.com/docker/libtrust fa567046d9b14f6aa788882a950d69651d230b21 github.com/docker/libtrust fa567046d9b14f6aa788882a950d69651d230b21
github.com/garyburd/redigo 535138d7bcd717d6531c701ef5933d98b1866257 github.com/garyburd/redigo 535138d7bcd717d6531c701ef5933d98b1866257
github.com/go-ini/ini 2ba15ac2dc9cdf88c110ec2dc0ced7fa45f5678c
github.com/golang/protobuf 8d92cf5fc15a4382f8964b08e1f42a75c0591aa3 github.com/golang/protobuf 8d92cf5fc15a4382f8964b08e1f42a75c0591aa3
github.com/gorilla/handlers 60c7bfde3e33c201519a200a4507a158cc03a17b github.com/gorilla/handlers 60c7bfde3e33c201519a200a4507a158cc03a17b
github.com/gorilla/mux 599cba5e7b6137d46ddf58fb1765f5d928e69604 github.com/gorilla/mux 599cba5e7b6137d46ddf58fb1765f5d928e69604

View file

@ -1,58 +1,159 @@
[![API Reference](http://img.shields.io/badge/api-reference-blue.svg)](http://docs.aws.amazon.com/sdk-for-go/api) [![Join the chat at https://gitter.im/aws/aws-sdk-go](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/aws/aws-sdk-go?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build Status](https://img.shields.io/travis/aws/aws-sdk-go.svg)](https://travis-ci.org/aws/aws-sdk-go) [![Apache V2 License](http://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/aws/aws-sdk-go/blob/master/LICENSE.txt)
# AWS SDK for Go # AWS SDK for Go
[![API Reference](https://img.shields.io/badge/api-reference-blue.svg)](https://docs.aws.amazon.com/sdk-for-go/api) [![Join the chat at https://gitter.im/aws/aws-sdk-go](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/aws/aws-sdk-go?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build status](https://github.com/aws/aws-sdk-go/actions/workflows/go.yml/badge.svg?branch=main)](https://github.com/aws/aws-sdk-go/actions/workflows/go.yml) [![Apache V2 License](https://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/aws/aws-sdk-go/blob/main/LICENSE.txt)
aws-sdk-go is the official AWS SDK for the Go programming language. aws-sdk-go is the official AWS SDK for the Go programming language.
Checkout our [release notes](https://github.com/aws/aws-sdk-go/releases) for information about the latest bug fixes, updates, and features added to the SDK. Checkout our [release notes](https://github.com/aws/aws-sdk-go/releases) for
information about the latest bug fixes, updates, and features added to the SDK.
We [announced](https://aws.amazon.com/blogs/developer/aws-sdk-for-go-2-0-developer-preview/) the Developer Preview for the [v2 AWS SDK for Go](https://github.com/aws/aws-sdk-go-v2). The v2 SDK is available at https://github.com/aws/aws-sdk-go-v2, and `go get github.com/aws/aws-sdk-go-v2` via `go get`. Check out the v2 SDK's [changes and updates](https://github.com/aws/aws-sdk-go-v2/blob/master/CHANGELOG.md), and let us know what you think. We want your feedback. We [announced](https://aws.amazon.com/blogs/developer/aws-sdk-for-go-version-2-general-availability/) the General Availability for the [AWS SDK for Go V2 (v2)](https://github.com/aws/aws-sdk-go-v2). The v2 SDK source is available at https://github.com/aws/aws-sdk-go-v2. Review the v2 SDK's [Developer Guide](https://aws.github.io/aws-sdk-go-v2/docs/) to get started with AWS SDK for Go V2 or review the [migration guide](https://aws.github.io/aws-sdk-go-v2/docs/migrating/) if you already use version 1.
## Installing Jump To:
* [Getting Started](#Getting-Started)
* [Quick Examples](#Quick-Examples)
* [Getting Help](#Getting-Help)
* [Contributing](#Contributing)
* [More Resources](#Resources)
If you are using Go 1.5 with the `GO15VENDOREXPERIMENT=1` vendoring flag, or 1.6 and higher you can use the following command to retrieve the SDK. The SDK's non-testing dependencies will be included and are vendored in the `vendor` folder. ## Getting Started
### Installing
Use `go get` to retrieve the SDK to add it to your `GOPATH` workspace, or
project's Go module dependencies.
go get github.com/aws/aws-sdk-go
To update the SDK use `go get -u` to retrieve the latest version of the SDK.
go get -u github.com/aws/aws-sdk-go go get -u github.com/aws/aws-sdk-go
Otherwise if your Go environment does not have vendoring support enabled, or you do not want to include the vendored SDK's dependencies you can use the following command to retrieve the SDK and its non-testing dependencies using `go get`. ### Dependencies
go get -u github.com/aws/aws-sdk-go/aws/... The SDK includes a `vendor` folder containing the runtime dependencies of the
go get -u github.com/aws/aws-sdk-go/service/... SDK. The metadata of the SDK's dependencies can be found in the Go module file
`go.mod` or Dep file `Gopkg.toml`.
If you're looking to retrieve just the SDK without any dependencies use the following command. ### Go Modules
go get -d github.com/aws/aws-sdk-go/ If you are using Go modules, your `go get` will default to the latest tagged
release version of the SDK. To get a specific release version of the SDK use
`@<tag>` in your `go get` command.
These two processes will still include the `vendor` folder and it should be deleted if its not going to be used by your environment. go get github.com/aws/aws-sdk-go@v1.15.77
To get the latest SDK repository change use `@latest`.
go get github.com/aws/aws-sdk-go@latest
### Go 1.5
If you are using Go 1.5 without vendoring enabled, (`GO15VENDOREXPERIMENT=1`),
you will need to use `...` when retrieving the SDK to get its dependencies.
go get github.com/aws/aws-sdk-go/...
This will still include the `vendor` folder. The `vendor` folder can be deleted
if not used by your environment.
rm -rf $GOPATH/src/github.com/aws/aws-sdk-go/vendor rm -rf $GOPATH/src/github.com/aws/aws-sdk-go/vendor
## Getting Help
Please use these community resources for getting help. We use the GitHub issues for tracking bugs and feature requests. ## Quick Examples
* Ask a question on [StackOverflow](http://stackoverflow.com/) and tag it with the [`aws-sdk-go`](http://stackoverflow.com/questions/tagged/aws-sdk-go) tag. ### Complete SDK Example
* Come join the AWS SDK for Go community chat on [gitter](https://gitter.im/aws/aws-sdk-go).
* Open a support ticket with [AWS Support](http://docs.aws.amazon.com/awssupport/latest/user/getting-started.html).
* If you think you may have found a bug, please open an [issue](https://github.com/aws/aws-sdk-go/issues/new).
## Opening Issues This example shows a complete working Go file which will upload a file to S3
and use the Context pattern to implement timeout logic that will cancel the
request if it takes too long. This example highlights how to use sessions,
create a service client, make a request, handle the error, and process the
response.
If you encounter a bug with the AWS SDK for Go we would like to hear about it. Search the [existing issues](https://github.com/aws/aws-sdk-go/issues) and see if others are also experiencing the issue before opening a new issue. Please include the version of AWS SDK for Go, Go language, and OS youre using. Please also include repro case when appropriate. ```go
package main
The GitHub issues are intended for bug reports and feature requests. For help and questions with using AWS SDK for GO please make use of the resources listed in the [Getting Help](https://github.com/aws/aws-sdk-go#getting-help) section. Keeping the list of open issues lean will help us respond in a timely manner. import (
"context"
"flag"
"fmt"
"os"
"time"
## Reference Documentation "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
[`Getting Started Guide`](https://aws.amazon.com/sdk-for-go/) - This document is a general introduction how to configure and make requests with the SDK. If this is your first time using the SDK, this documentation and the API documentation will help you get started. This document focuses on the syntax and behavior of the SDK. The [Service Developer Guide](https://aws.amazon.com/documentation/) will help you get started using specific AWS services. // Uploads a file to S3 given a bucket and object key. Also takes a duration
// value to terminate the update if it doesn't complete within that time.
//
// The AWS Region needs to be provided in the AWS shared config or on the
// environment variable as `AWS_REGION`. Credentials also must be provided
// Will default to shared config file, but can load from environment if provided.
//
// Usage:
// # Upload myfile.txt to myBucket/myKey. Must complete within 10 minutes or will fail
// go run withContext.go -b mybucket -k myKey -d 10m < myfile.txt
func main() {
var bucket, key string
var timeout time.Duration
[`SDK API Reference Documentation`](https://docs.aws.amazon.com/sdk-for-go/api/) - Use this document to look up all API operation input and output parameters for AWS services supported by the SDK. The API reference also includes documentation of the SDK, and examples how to using the SDK, service client API operations, and API operation require parameters. flag.StringVar(&bucket, "b", "", "Bucket name.")
flag.StringVar(&key, "k", "", "Object key name.")
flag.DurationVar(&timeout, "d", 0, "Upload timeout.")
flag.Parse()
[`Service Developer Guide`](https://aws.amazon.com/documentation/) - Use this documentation to learn how to interface with an AWS service. These are great guides both, if you're getting started with a service, or looking for more information on a service. You should not need this document for coding, though in some cases, services may supply helpful samples that you might want to look out for. // All clients require a Session. The Session provides the client with
// shared configuration such as region, endpoint, and credentials. A
// Session should be shared where possible to take advantage of
// configuration and credential caching. See the session package for
// more information.
sess := session.Must(session.NewSession())
[`SDK Examples`](https://github.com/aws/aws-sdk-go/tree/master/example) - Included in the SDK's repo are a several hand crafted examples using the SDK features and AWS services. // Create a new instance of the service's client with a Session.
// Optional aws.Config values can also be provided as variadic arguments
// to the New function. This option allows you to provide service
// specific configuration.
svc := s3.New(sess)
## Overview of SDK's Packages // Create a context with a timeout that will abort the upload if it takes
// more than the passed in timeout.
ctx := context.Background()
var cancelFn func()
if timeout > 0 {
ctx, cancelFn = context.WithTimeout(ctx, timeout)
}
// Ensure the context is canceled to prevent leaking.
// See context package for more information, https://golang.org/pkg/context/
if cancelFn != nil {
defer cancelFn()
}
// Uploads the object to S3. The Context will interrupt the request if the
// timeout expires.
_, err := svc.PutObjectWithContext(ctx, &s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Body: os.Stdin,
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == request.CanceledErrorCode {
// If the SDK can determine the request or retry delay was canceled
// by a context the CanceledErrorCode error code will be returned.
fmt.Fprintf(os.Stderr, "upload canceled due to timeout, %v\n", err)
} else {
fmt.Fprintf(os.Stderr, "failed to upload object, %v\n", err)
}
os.Exit(1)
}
fmt.Printf("successfully uploaded file to %s/%s\n", bucket, key)
}
```
### Overview of SDK's Packages
The SDK is composed of two main components, SDK core, and service clients. The SDK is composed of two main components, SDK core, and service clients.
The SDK core packages are all available under the aws package at the root of The SDK core packages are all available under the aws package at the root of
@ -90,12 +191,11 @@ package under the service folder at the root of the SDK.
* service - Clients for AWS services. All services supported by the SDK are * service - Clients for AWS services. All services supported by the SDK are
available under this folder. available under this folder.
## How to Use the SDK's AWS Service Clients ### How to Use the SDK's AWS Service Clients
The SDK includes the Go types and utilities you can use to make requests to The SDK includes the Go types and utilities you can use to make requests to
AWS service APIs. Within the service folder at the root of the SDK you'll find AWS service APIs. Within the service folder at the root of the SDK you'll find
a package for each AWS service the SDK supports. All service clients follows a package for each AWS service the SDK supports. All service clients follow common pattern of creation and usage.
a common pattern of creation and usage.
When creating a client for an AWS service you'll first need to have a Session When creating a client for an AWS service you'll first need to have a Session
value constructed. The Session provides shared configuration that can be shared value constructed. The Session provides shared configuration that can be shared
@ -107,7 +207,7 @@ configuration.
Once the service's client is created you can use it to make API requests the Once the service's client is created you can use it to make API requests the
AWS service. These clients are safe to use concurrently. AWS service. These clients are safe to use concurrently.
## Configuring the SDK ### Configuring the SDK
In the AWS SDK for Go, you can configure settings for service clients, such In the AWS SDK for Go, you can configure settings for service clients, such
as the log level and maximum number of retries. Most settings are optional; as the log level and maximum number of retries. Most settings are optional;
@ -237,7 +337,7 @@ options such as setting the Endpoint, and other service client configuration opt
[endpoints_pkg]: https://docs.aws.amazon.com/sdk-for-go/api/aws/endpoints/ [endpoints_pkg]: https://docs.aws.amazon.com/sdk-for-go/api/aws/endpoints/
## Making API Requests ### Making API Requests
Once the client is created you can make an API request to the service. Once the client is created you can make an API request to the service.
Each API method takes a input parameter, and returns the service response Each API method takes a input parameter, and returns the service response
@ -354,98 +454,73 @@ be request.WaiterResourceNotReadyErrorCode.
fmt.Println("Bucket", myBucket, "exists") fmt.Println("Bucket", myBucket, "exists")
``` ```
## Complete SDK Example ## Getting Help
This example shows a complete working Go file which will upload a file to S3 Please use these community resources for getting help. We use the GitHub issues
and use the Context pattern to implement timeout logic that will cancel the for tracking bugs and feature requests.
request if it takes too long. This example highlights how to use sessions,
create a service client, make a request, handle the error, and process the
response.
```go * Ask a question on [StackOverflow](http://stackoverflow.com/) and tag it with the [`aws-sdk-go`](http://stackoverflow.com/questions/tagged/aws-sdk-go) tag.
package main * Come join the AWS SDK for Go community chat on [gitter](https://gitter.im/aws/aws-sdk-go).
* Open a support ticket with [AWS Support](http://docs.aws.amazon.com/awssupport/latest/user/getting-started.html).
* If you think you may have found a bug, please open an [issue](https://github.com/aws/aws-sdk-go/issues/new/choose).
import ( This SDK implements AWS service APIs. For general issues regarding the AWS services and their limitations, you may also take a look at the [Amazon Web Services Discussion Forums](https://forums.aws.amazon.com/).
"context"
"flag"
"fmt"
"os"
"time"
"github.com/aws/aws-sdk-go/aws" ### Opening Issues
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
// Uploads a file to S3 given a bucket and object key. Also takes a duration If you encounter a bug with the AWS SDK for Go we would like to hear about it.
// value to terminate the update if it doesn't complete within that time. Search the [existing issues](https://github.com/aws/aws-sdk-go/issues) and see
// if others are also experiencing the issue before opening a new issue. Please
// The AWS Region needs to be provided in the AWS shared config or on the include the version of AWS SDK for Go, Go language, and OS youre using. Please
// environment variable as `AWS_REGION`. Credentials also must be provided also include reproduction case when appropriate.
// Will default to shared config file, but can load from environment if provided.
//
// Usage:
// # Upload myfile.txt to myBucket/myKey. Must complete within 10 minutes or will fail
// go run withContext.go -b mybucket -k myKey -d 10m < myfile.txt
func main() {
var bucket, key string
var timeout time.Duration
flag.StringVar(&bucket, "b", "", "Bucket name.") The GitHub issues are intended for bug reports and feature requests. For help
flag.StringVar(&key, "k", "", "Object key name.") and questions with using AWS SDK for Go please make use of the resources listed
flag.DurationVar(&timeout, "d", 0, "Upload timeout.") in the [Getting Help](https://github.com/aws/aws-sdk-go#getting-help) section.
flag.Parse() Keeping the list of open issues lean will help us respond in a timely manner.
// All clients require a Session. The Session provides the client with ## Contributing
// shared configuration such as region, endpoint, and credentials. A
// Session should be shared where possible to take advantage of
// configuration and credential caching. See the session package for
// more information.
sess := session.Must(session.NewSession())
// Create a new instance of the service's client with a Session. We work hard to provide a high-quality and useful SDK for our AWS services, and we greatly value feedback and contributions from our community. Please review our [contributing guidelines](./CONTRIBUTING.md) before submitting any [issues] or [pull requests][pr] to ensure we have all the necessary information to effectively respond to your bug report or contribution.
// Optional aws.Config values can also be provided as variadic arguments
// to the New function. This option allows you to provide service
// specific configuration.
svc := s3.New(sess)
// Create a context with a timeout that will abort the upload if it takes ## Maintenance and support for SDK major versions
// more than the passed in timeout.
ctx := context.Background()
var cancelFn func()
if timeout > 0 {
ctx, cancelFn = context.WithTimeout(ctx, timeout)
}
// Ensure the context is canceled to prevent leaking.
// See context package for more information, https://golang.org/pkg/context/
defer cancelFn()
// Uploads the object to S3. The Context will interrupt the request if the For information about maintenance and support for SDK major versions and our underlying dependencies, see the following in the AWS SDKs and Tools Shared Configuration and Credentials Reference Guide:
// timeout expires.
_, err := svc.PutObjectWithContext(ctx, &s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Body: os.Stdin,
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == request.CanceledErrorCode {
// If the SDK can determine the request or retry delay was canceled
// by a context the CanceledErrorCode error code will be returned.
fmt.Fprintf(os.Stderr, "upload canceled due to timeout, %v\n", err)
} else {
fmt.Fprintf(os.Stderr, "failed to upload object, %v\n", err)
}
os.Exit(1)
}
fmt.Printf("successfully uploaded file to %s/%s\n", bucket, key) * [AWS SDKs and Tools Maintenance Policy](https://docs.aws.amazon.com/credref/latest/refdocs/maint-policy.html)
} * [AWS SDKs and Tools Version Support Matrix](https://docs.aws.amazon.com/credref/latest/refdocs/version-support-matrix.html)
```
## License ## Resources
This SDK is distributed under the [Developer guide](https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/welcome.html) - This document
[Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0), is a general introduction on how to configure and make requests with the SDK.
see LICENSE.txt and NOTICE.txt for more information. If this is your first time using the SDK, this documentation and the API
documentation will help you get started. This document focuses on the syntax
and behavior of the SDK. The [Service Developer Guide](https://aws.amazon.com/documentation/)
will help you get started using specific AWS services.
[SDK API Reference Documentation](https://docs.aws.amazon.com/sdk-for-go/api/) - Use this
document to look up all API operation input and output parameters for AWS
services supported by the SDK. The API reference also includes documentation of
the SDK, and examples how to using the SDK, service client API operations, and
API operation require parameters.
[Service Documentation](https://aws.amazon.com/documentation/) - Use this
documentation to learn how to interface with AWS services. These guides are
great for getting started with a service, or when looking for more
information about a service. While this document is not required for coding,
services may supply helpful samples to look out for.
[SDK Examples](https://github.com/aws/aws-sdk-go/tree/main/example) -
Included in the SDK's repo are several hand crafted examples using the SDK
features and AWS services.
[Forum](https://forums.aws.amazon.com/forum.jspa?forumID=293) - Ask questions, get help, and give feedback
[Issues][issues] - Report issues, submit pull requests, and get involved
(see [Apache 2.0 License][license])
[issues]: https://github.com/aws/aws-sdk-go/issues
[pr]: https://github.com/aws/aws-sdk-go/pulls
[license]: http://aws.amazon.com/apache2.0/

93
vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go generated vendored Normal file
View file

@ -0,0 +1,93 @@
// Package arn provides a parser for interacting with Amazon Resource Names.
package arn
import (
"errors"
"strings"
)
const (
arnDelimiter = ":"
arnSections = 6
arnPrefix = "arn:"
// zero-indexed
sectionPartition = 1
sectionService = 2
sectionRegion = 3
sectionAccountID = 4
sectionResource = 5
// errors
invalidPrefix = "arn: invalid prefix"
invalidSections = "arn: not enough sections"
)
// ARN captures the individual fields of an Amazon Resource Name.
// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more information.
type ARN struct {
// The partition that the resource is in. For standard AWS regions, the partition is "aws". If you have resources in
// other partitions, the partition is "aws-partitionname". For example, the partition for resources in the China
// (Beijing) region is "aws-cn".
Partition string
// The service namespace that identifies the AWS product (for example, Amazon S3, IAM, or Amazon RDS). For a list of
// namespaces, see
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces.
Service string
// The region the resource resides in. Note that the ARNs for some resources do not require a region, so this
// component might be omitted.
Region string
// The ID of the AWS account that owns the resource, without the hyphens. For example, 123456789012. Note that the
// ARNs for some resources don't require an account number, so this component might be omitted.
AccountID string
// The content of this part of the ARN varies by service. It often includes an indicator of the type of resource —
// for example, an IAM user or Amazon RDS database - followed by a slash (/) or a colon (:), followed by the
// resource name itself. Some services allows paths for resource names, as described in
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#arns-paths.
Resource string
}
// Parse parses an ARN into its constituent parts.
//
// Some example ARNs:
// arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment
// arn:aws:iam::123456789012:user/David
// arn:aws:rds:eu-west-1:123456789012:db:mysql-db
// arn:aws:s3:::my_corporate_bucket/exampleobject.png
func Parse(arn string) (ARN, error) {
if !strings.HasPrefix(arn, arnPrefix) {
return ARN{}, errors.New(invalidPrefix)
}
sections := strings.SplitN(arn, arnDelimiter, arnSections)
if len(sections) != arnSections {
return ARN{}, errors.New(invalidSections)
}
return ARN{
Partition: sections[sectionPartition],
Service: sections[sectionService],
Region: sections[sectionRegion],
AccountID: sections[sectionAccountID],
Resource: sections[sectionResource],
}, nil
}
// IsARN returns whether the given string is an ARN by looking for
// whether the string starts with "arn:" and contains the correct number
// of sections delimited by colons(:).
func IsARN(arn string) bool {
return strings.HasPrefix(arn, arnPrefix) && strings.Count(arn, ":") >= arnSections-1
}
// String returns the canonical representation of the ARN
func (arn ARN) String() string {
return arnPrefix +
arn.Partition + arnDelimiter +
arn.Service + arnDelimiter +
arn.Region + arnDelimiter +
arn.AccountID + arnDelimiter +
arn.Resource
}

View file

@ -138,8 +138,27 @@ type RequestFailure interface {
RequestID() string RequestID() string
} }
// NewRequestFailure returns a new request error wrapper for the given Error // NewRequestFailure returns a wrapped error with additional information for
// provided. // request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure { func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
return newRequestError(err, statusCode, reqID) return newRequestError(err, statusCode, reqID)
} }
// UnmarshalError provides the interface for the SDK failing to unmarshal data.
type UnmarshalError interface {
awsError
Bytes() []byte
}
// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding
// the bytes that fail to unmarshal to the error.
func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError {
return &unmarshalError{
awsError: New("UnmarshalError", msg, err),
bytes: bytes,
}
}

View file

@ -1,6 +1,9 @@
package awserr package awserr
import "fmt" import (
"encoding/hex"
"fmt"
)
// SprintError returns a string of the formatted error code. // SprintError returns a string of the formatted error code.
// //
@ -119,6 +122,7 @@ type requestError struct {
awsError awsError
statusCode int statusCode int
requestID string requestID string
bytes []byte
} }
// newRequestError returns a wrapped error with additional information for // newRequestError returns a wrapped error with additional information for
@ -170,6 +174,29 @@ func (r requestError) OrigErrs() []error {
return []error{r.OrigErr()} return []error{r.OrigErr()}
} }
type unmarshalError struct {
awsError
bytes []byte
}
// Error returns the string representation of the error.
// Satisfies the error interface.
func (e unmarshalError) Error() string {
extra := hex.Dump(e.bytes)
return SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (e unmarshalError) String() string {
return e.Error()
}
// Bytes returns the bytes that failed to unmarshal.
func (e unmarshalError) Bytes() []byte {
return e.bytes
}
// An error list that satisfies the golang interface // An error list that satisfies the golang interface
type errorList []error type errorList []error
@ -181,7 +208,7 @@ func (e errorList) Error() string {
// How do we want to handle the array size being zero // How do we want to handle the array size being zero
if size := len(e); size > 0 { if size := len(e); size > 0 {
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
msg += fmt.Sprintf("%s", e[i].Error()) msg += e[i].Error()
// We check the next index to see if it is within the slice. // We check the next index to see if it is within the slice.
// If it is, then we append a newline. We do this, because unit tests // If it is, then we append a newline. We do this, because unit tests
// could be broken with the additional '\n' // could be broken with the additional '\n'

View file

@ -15,7 +15,7 @@ func DeepEqual(a, b interface{}) bool {
rb := reflect.Indirect(reflect.ValueOf(b)) rb := reflect.Indirect(reflect.ValueOf(b))
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid { if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
// If the elements are both nil, and of the same type the are equal // If the elements are both nil, and of the same type they are equal
// If they are of different types they are not equal // If they are of different types they are not equal
return reflect.TypeOf(a) == reflect.TypeOf(b) return reflect.TypeOf(a) == reflect.TypeOf(b)
} else if raValid != rbValid { } else if raValid != rbValid {

View file

@ -70,7 +70,7 @@ func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTer
value = value.FieldByNameFunc(func(name string) bool { value = value.FieldByNameFunc(func(name string) bool {
if c == name { if c == name {
return true return true
} else if !caseSensitive && strings.ToLower(name) == strings.ToLower(c) { } else if !caseSensitive && strings.EqualFold(name, c) {
return true return true
} }
return false return false
@ -185,14 +185,13 @@ func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
// SetValueAtPath sets a value at the case insensitive lexical path inside // SetValueAtPath sets a value at the case insensitive lexical path inside
// of a structure. // of a structure.
func SetValueAtPath(i interface{}, path string, v interface{}) { func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, false, v == nil); rvals != nil { rvals := rValuesAtPath(i, path, true, false, v == nil)
for _, rval := range rvals { for _, rval := range rvals {
if rval.Kind() == reflect.Ptr && rval.IsNil() { if rval.Kind() == reflect.Ptr && rval.IsNil() {
continue continue
} }
setValue(rval, v) setValue(rval, v)
} }
}
} }
func setValue(dstVal reflect.Value, src interface{}) { func setValue(dstVal reflect.Value, src interface{}) {

View file

@ -50,9 +50,19 @@ func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
for i, n := range names { for i, n := range names {
val := v.FieldByName(n) val := v.FieldByName(n)
ft, ok := v.Type().FieldByName(n)
if !ok {
panic(fmt.Sprintf("expected to find field %v on type %v, but was not found", n, v.Type()))
}
buf.WriteString(strings.Repeat(" ", indent+2)) buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ") buf.WriteString(n + ": ")
if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString("<sensitive>")
} else {
prettify(val, indent+2, buf) prettify(val, indent+2, buf)
}
if i < len(names)-1 { if i < len(names)-1 {
buf.WriteString(",\n") buf.WriteString(",\n")

View file

@ -8,6 +8,8 @@ import (
) )
// StringValue returns the string representation of a value. // StringValue returns the string representation of a value.
//
// Deprecated: Use Prettify instead.
func StringValue(i interface{}) string { func StringValue(i interface{}) string {
var buf bytes.Buffer var buf bytes.Buffer
stringValue(reflect.ValueOf(i), 0, &buf) stringValue(reflect.ValueOf(i), 0, &buf)
@ -23,28 +25,27 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
case reflect.Struct: case reflect.Struct:
buf.WriteString("{\n") buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ { for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name ft := v.Type().Field(i)
f := v.Field(i) fv := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
continue // ignore unexported fields continue // ignore unexported fields
} }
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() { if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
continue // ignore unset fields continue // ignore unset fields
} }
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2)) buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ") buf.WriteString(ft.Name + ": ")
stringValue(val, indent+2, buf)
if i < len(names)-1 { if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString(",\n") buf.WriteString("<sensitive>")
} else {
stringValue(fv, indent+2, buf)
} }
buf.WriteString(",\n")
} }
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")

View file

@ -12,13 +12,15 @@ import (
type Config struct { type Config struct {
Config *aws.Config Config *aws.Config
Handlers request.Handlers Handlers request.Handlers
PartitionID string
Endpoint string Endpoint string
SigningRegion string SigningRegion string
SigningName string SigningName string
ResolvedRegion string
// States that the signing name did not come from a modeled source but // States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors // was derived based on other data. Used by service client constructors
// to determine if the signin name can be overriden based on metadata the // to determine if the signin name can be overridden based on metadata the
// service has. // service has.
SigningNameDerived bool SigningNameDerived bool
} }
@ -64,7 +66,7 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
default: default:
maxRetries := aws.IntValue(cfg.MaxRetries) maxRetries := aws.IntValue(cfg.MaxRetries)
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries { if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
maxRetries = 3 maxRetries = DefaultRetryerMaxNumRetries
} }
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries} svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
} }
@ -87,10 +89,6 @@ func (c *Client) NewRequest(operation *request.Operation, params interface{}, da
// AddDebugHandlers injects debug logging handlers into the service to log request // AddDebugHandlers injects debug logging handlers into the service to log request
// debug information. // debug information.
func (c *Client) AddDebugHandlers() { func (c *Client) AddDebugHandlers() {
if !c.Config.LogLevel.AtLeast(aws.LogDebug) {
return
}
c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler) c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler)
c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler) c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler)
} }

View file

@ -1,6 +1,7 @@
package client package client
import ( import (
"math"
"strconv" "strconv"
"time" "time"
@ -9,82 +10,142 @@ import (
) )
// DefaultRetryer implements basic retry logic using exponential backoff for // DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, implement the // most services. If you want to implement custom retry logic, you can implement the
// request.Retryer interface or create a structure type that composes this // request.Retryer interface.
// struct and override the specific methods. For example, to override only
// the MaxRetries method:
// //
// type retryer struct {
// client.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() int { return 100 }
type DefaultRetryer struct { type DefaultRetryer struct {
// Num max Retries is the number of max retries that will be performed.
// By default, this is zero.
NumMaxRetries int NumMaxRetries int
// MinRetryDelay is the minimum retry delay after which retry will be performed.
// If not set, the value is 0ns.
MinRetryDelay time.Duration
// MinThrottleRetryDelay is the minimum retry delay when throttled.
// If not set, the value is 0ns.
MinThrottleDelay time.Duration
// MaxRetryDelay is the maximum retry delay before which retry must be performed.
// If not set, the value is 0ns.
MaxRetryDelay time.Duration
// MaxThrottleDelay is the maximum retry delay when throttled.
// If not set, the value is 0ns.
MaxThrottleDelay time.Duration
} }
const (
// DefaultRetryerMaxNumRetries sets maximum number of retries
DefaultRetryerMaxNumRetries = 3
// DefaultRetryerMinRetryDelay sets minimum retry delay
DefaultRetryerMinRetryDelay = 30 * time.Millisecond
// DefaultRetryerMinThrottleDelay sets minimum delay when throttled
DefaultRetryerMinThrottleDelay = 500 * time.Millisecond
// DefaultRetryerMaxRetryDelay sets maximum retry delay
DefaultRetryerMaxRetryDelay = 300 * time.Second
// DefaultRetryerMaxThrottleDelay sets maximum delay when throttled
DefaultRetryerMaxThrottleDelay = 300 * time.Second
)
// MaxRetries returns the number of maximum returns the service will use to make // MaxRetries returns the number of maximum returns the service will use to make
// an individual API request. // an individual API request.
func (d DefaultRetryer) MaxRetries() int { func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries return d.NumMaxRetries
} }
// setRetryerDefaults sets the default values of the retryer if not set
func (d *DefaultRetryer) setRetryerDefaults() {
if d.MinRetryDelay == 0 {
d.MinRetryDelay = DefaultRetryerMinRetryDelay
}
if d.MaxRetryDelay == 0 {
d.MaxRetryDelay = DefaultRetryerMaxRetryDelay
}
if d.MinThrottleDelay == 0 {
d.MinThrottleDelay = DefaultRetryerMinThrottleDelay
}
if d.MaxThrottleDelay == 0 {
d.MaxThrottleDelay = DefaultRetryerMaxThrottleDelay
}
}
// RetryRules returns the delay duration before retrying this request again // RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration { func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// Set the upper limit of delay in retrying at ~five minutes
minTime := 30 // if number of max retries is zero, no retries will be performed.
throttle := d.shouldThrottle(r) if d.NumMaxRetries == 0 {
if throttle { return 0
if delay, ok := getRetryDelay(r); ok {
return delay
} }
minTime = 500 // Sets default value for retryer members
d.setRetryerDefaults()
// minDelay is the minimum retryer delay
minDelay := d.MinRetryDelay
var initialDelay time.Duration
isThrottle := r.IsErrorThrottle()
if isThrottle {
if delay, ok := getRetryAfterDelay(r); ok {
initialDelay = delay
}
minDelay = d.MinThrottleDelay
} }
retryCount := r.RetryCount retryCount := r.RetryCount
if throttle && retryCount > 8 {
retryCount = 8 // maxDelay the maximum retryer delay
} else if retryCount > 13 { maxDelay := d.MaxRetryDelay
retryCount = 13
if isThrottle {
maxDelay = d.MaxThrottleDelay
} }
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime) var delay time.Duration
return time.Duration(delay) * time.Millisecond
// Logic to cap the retry count based on the minDelay provided
actualRetryCount := int(math.Log2(float64(minDelay))) + 1
if actualRetryCount < 63-retryCount {
delay = time.Duration(1<<uint64(retryCount)) * getJitterDelay(minDelay)
if delay > maxDelay {
delay = getJitterDelay(maxDelay / 2)
}
} else {
delay = getJitterDelay(maxDelay / 2)
}
return delay + initialDelay
}
// getJitterDelay returns a jittered delay for retry
func getJitterDelay(duration time.Duration) time.Duration {
return time.Duration(sdkrand.SeededRand.Int63n(int64(duration)) + int64(duration))
} }
// ShouldRetry returns true if the request should be retried. // ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool { func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// ShouldRetry returns false if number of max retries is 0.
if d.NumMaxRetries == 0 {
return false
}
// If one of the other handlers already set the retry state // If one of the other handlers already set the retry state
// we don't want to override it based on the service's state // we don't want to override it based on the service's state
if r.Retryable != nil { if r.Retryable != nil {
return *r.Retryable return *r.Retryable
} }
return r.IsErrorRetryable() || r.IsErrorThrottle()
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
return true
}
return r.IsErrorRetryable() || d.shouldThrottle(r)
}
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 502:
case 503:
case 504:
default:
return r.IsErrorThrottle()
}
return true
} }
// This will look in the Retry-After header, RFC 7231, for how long // This will look in the Retry-After header, RFC 7231, for how long
// it will wait before attempting another request // it will wait before attempting another request
func getRetryDelay(r *request.Request) (time.Duration, bool) { func getRetryAfterDelay(r *request.Request) (time.Duration, bool) {
if !canUseRetryAfterHeader(r) { if !canUseRetryAfterHeader(r) {
return 0, false return 0, false
} }

View file

@ -53,6 +53,10 @@ var LogHTTPRequestHandler = request.NamedHandler{
} }
func logRequest(r *request.Request) { func logRequest(r *request.Request) {
if !r.Config.LogLevel.AtLeast(aws.LogDebug) || r.Config.Logger == nil {
return
}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
bodySeekable := aws.IsReaderSeekable(r.Body) bodySeekable := aws.IsReaderSeekable(r.Body)
@ -67,10 +71,14 @@ func logRequest(r *request.Request) {
if !bodySeekable { if !bodySeekable {
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body)) r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
} }
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's // Reset the request body because dumpRequest will re-wrap the
// Body as a NoOpCloser and will not be reset after read by the HTTP // r.HTTPRequest's Body as a NoOpCloser and will not be reset after
// client reader. // read by the HTTP client reader.
r.ResetBody() if err := r.Error; err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
} }
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
@ -86,6 +94,10 @@ var LogHTTPRequestHeaderHandler = request.NamedHandler{
} }
func logRequestHeader(r *request.Request) { func logRequestHeader(r *request.Request) {
if !r.Config.LogLevel.AtLeast(aws.LogDebug) || r.Config.Logger == nil {
return
}
b, err := httputil.DumpRequestOut(r.HTTPRequest, false) b, err := httputil.DumpRequestOut(r.HTTPRequest, false)
if err != nil { if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
@ -116,8 +128,18 @@ var LogHTTPResponseHandler = request.NamedHandler{
} }
func logResponse(r *request.Request) { func logResponse(r *request.Request) {
if !r.Config.LogLevel.AtLeast(aws.LogDebug) || r.Config.Logger == nil {
return
}
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)} lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
if r.HTTPResponse == nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, "request's HTTPResponse is nil"))
return
}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
if logBody { if logBody {
r.HTTPResponse.Body = &teeReaderCloser{ r.HTTPResponse.Body = &teeReaderCloser{
@ -168,7 +190,7 @@ var LogHTTPResponseHeaderHandler = request.NamedHandler{
} }
func logResponseHeader(r *request.Request) { func logResponseHeader(r *request.Request) {
if r.Config.Logger == nil { if !r.Config.LogLevel.AtLeast(aws.LogDebug) || r.Config.Logger == nil {
return return
} }

View file

@ -5,9 +5,11 @@ type ClientInfo struct {
ServiceName string ServiceName string
ServiceID string ServiceID string
APIVersion string APIVersion string
PartitionID string
Endpoint string Endpoint string
SigningName string SigningName string
SigningRegion string SigningRegion string
JSONVersion string JSONVersion string
TargetPrefix string TargetPrefix string
ResolvedRegion string
} }

View file

@ -0,0 +1,28 @@
package client
import (
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
// NoOpRetryer provides a retryer that performs no retries.
// It should be used when we do not want retries to be performed.
type NoOpRetryer struct{}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
func (d NoOpRetryer) MaxRetries() int {
return 0
}
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
func (d NoOpRetryer) ShouldRetry(_ *request.Request) bool {
return false
}
// RetryRules returns the delay duration before retrying this request again;
// since NoOpRetryer does not retry, RetryRules always returns 0.
func (d NoOpRetryer) RetryRules(_ *request.Request) time.Duration {
return 0
}

View file

@ -18,9 +18,9 @@ const UseServiceDefaultRetries = -1
type RequestRetryer interface{} type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default, // A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig tructure. // all clients will use the defaults.DefaultConfig structure.
// //
// // Create Session with MaxRetry configuration to be shared by multiple // // Create Session with MaxRetries configuration to be shared by multiple
// // service clients. // // service clients.
// sess := session.Must(session.NewSession(&aws.Config{ // sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3), // MaxRetries: aws.Int(3),
@ -43,9 +43,9 @@ type Config struct {
// An optional endpoint URL (hostname only or fully qualified URI) // An optional endpoint URL (hostname only or fully qualified URI)
// that overrides the default generated endpoint for a client. Set this // that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint. // to `nil` or the value to `""` to use the default generated endpoint.
// //
// @note You must still provide a `Region` value when specifying an // Note: You must still provide a `Region` value when specifying an
// endpoint for a client. // endpoint for a client.
Endpoint *string Endpoint *string
@ -65,8 +65,8 @@ type Config struct {
// noted. A full list of regions is found in the "Regions and Endpoints" // noted. A full list of regions is found in the "Regions and Endpoints"
// document. // document.
// //
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html // See http://docs.aws.amazon.com/general/latest/gr/rande.html for AWS
// AWS Regions and Endpoints // Regions and Endpoints.
Region *string Region *string
// Set this to `true` to disable SSL when sending requests. Defaults // Set this to `true` to disable SSL when sending requests. Defaults
@ -120,9 +120,10 @@ type Config struct {
// will use virtual hosted bucket addressing when possible // will use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`). // (`http://BUCKET.s3.amazonaws.com/KEY`).
// //
// @note This configuration option is specific to the Amazon S3 service. // Note: This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html //
// Amazon S3: Virtual Hosting of Buckets // See http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// for Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool S3ForcePathStyle *bool
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue` // Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
@ -137,7 +138,7 @@ type Config struct {
// `ExpectContinueTimeout` for information on adjusting the continue wait // `ExpectContinueTimeout` for information on adjusting the continue wait
// timeout. https://golang.org/pkg/net/http/#Transport // timeout. https://golang.org/pkg/net/http/#Transport
// //
// You should use this flag to disble 100-Continue if you experience issues // You should use this flag to disable 100-Continue if you experience issues
// with proxies or third party S3 compatible services. // with proxies or third party S3 compatible services.
S3Disable100Continue *bool S3Disable100Continue *bool
@ -160,6 +161,20 @@ type Config struct {
// on GetObject API calls. // on GetObject API calls.
S3DisableContentMD5Validation *bool S3DisableContentMD5Validation *bool
// Set this to `true` to have the S3 service client to use the region specified
// in the ARN, when an ARN is provided as an argument to a bucket parameter.
S3UseARNRegion *bool
// Set this to `true` to enable the SDK to unmarshal API response header maps to
// normalized lower case map keys.
//
// For example S3's X-Amz-Meta prefixed header will be unmarshaled to lower case
// Metadata member's map keys. The value of the header in the map is unaffected.
//
// The AWS SDK for Go v2, uses lower case header maps by default. The v1
// SDK provides this opt-in for this option, for backwards compatibility.
LowerCaseHeaderMaps *bool
// Set this to `true` to disable the EC2Metadata client from overriding the // Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the // default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This options is only // EC2Metadata client to create a new http.Client. This options is only
@ -171,7 +186,7 @@ type Config struct {
// //
// Example: // Example:
// sess := session.Must(session.NewSession(aws.NewConfig() // sess := session.Must(session.NewSession(aws.NewConfig()
// .WithEC2MetadataDiableTimeoutOverride(true))) // .WithEC2MetadataDisableTimeoutOverride(true)))
// //
// svc := s3.New(sess) // svc := s3.New(sess)
// //
@ -182,7 +197,7 @@ type Config struct {
// both IPv4 and IPv6 addressing. // both IPv4 and IPv6 addressing.
// //
// Setting this for a service which does not support dual stack will fail // Setting this for a service which does not support dual stack will fail
// to make requets. It is not recommended to set this value on the session // to make requests. It is not recommended to set this value on the session
// as it will apply to all service clients created with the session. Even // as it will apply to all service clients created with the session. Even
// services which don't support dual stack endpoints. // services which don't support dual stack endpoints.
// //
@ -196,8 +211,19 @@ type Config struct {
// svc := s3.New(sess, &aws.Config{ // svc := s3.New(sess, &aws.Config{
// UseDualStack: aws.Bool(true), // UseDualStack: aws.Bool(true),
// }) // })
//
// Deprecated: This option will continue to function for S3 and S3 Control for backwards compatibility.
// UseDualStackEndpoint should be used to enable usage of a service's dual-stack endpoint for all service clients
// moving forward. For S3 and S3 Control, when UseDualStackEndpoint is set to a non-zero value it takes higher
// precedence then this option.
UseDualStack *bool UseDualStack *bool
// Sets the resolver to resolve a dual-stack endpoint for the service.
UseDualStackEndpoint endpoints.DualStackEndpointState
// UseFIPSEndpoint specifies the resolver must resolve a FIPS endpoint.
UseFIPSEndpoint endpoints.FIPSEndpointState
// SleepDelay is an override for the func the SDK will call when sleeping // SleepDelay is an override for the func the SDK will call when sleeping
// during the lifecycle of a request. Specifically this will be used for // during the lifecycle of a request. Specifically this will be used for
// request delays. This value should only be used for testing. To adjust // request delays. This value should only be used for testing. To adjust
@ -223,12 +249,41 @@ type Config struct {
// Key: aws.String("//foo//bar//moo"), // Key: aws.String("//foo//bar//moo"),
// }) // })
DisableRestProtocolURICleaning *bool DisableRestProtocolURICleaning *bool
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
// have the definition in its model. By default, endpoint discovery is off.
// To use EndpointDiscovery, Endpoint should be unset or set to an empty string.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
// EnableEndpointDiscovery: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("/foo/bar/moo"),
// })
EnableEndpointDiscovery *bool
// DisableEndpointHostPrefix will disable the SDK's behavior of prefixing
// request endpoint hosts with modeled information.
//
// Disabling this feature is useful when you want to use local endpoints
// for testing that do not support the modeled host prefix pattern.
DisableEndpointHostPrefix *bool
// STSRegionalEndpoint will enable regional or legacy endpoint resolving
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// S3UsEast1RegionalEndpoint will enable regional or legacy endpoint resolving
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
} }
// NewConfig returns a new Config pointer that can be chained with builder // NewConfig returns a new Config pointer that can be chained with builder
// methods to set multiple configuration values inline without using pointers. // methods to set multiple configuration values inline without using pointers.
// //
// // Create Session with MaxRetry configuration to be shared by multiple // // Create Session with MaxRetries configuration to be shared by multiple
// // service clients. // // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig(). // sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3), // WithMaxRetries(3),
@ -356,6 +411,13 @@ func (c *Config) WithS3DisableContentMD5Validation(enable bool) *Config {
} }
// WithS3UseARNRegion sets a config S3UseARNRegion value and
// returning a Config pointer for chaining
func (c *Config) WithS3UseARNRegion(enable bool) *Config {
c.S3UseARNRegion = &enable
return c
}
// WithUseDualStack sets a config UseDualStack value returning a Config // WithUseDualStack sets a config UseDualStack value returning a Config
// pointer for chaining. // pointer for chaining.
func (c *Config) WithUseDualStack(enable bool) *Config { func (c *Config) WithUseDualStack(enable bool) *Config {
@ -377,6 +439,47 @@ func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
return c return c
} }
// WithEndpointDiscovery will set whether or not to use endpoint discovery.
func (c *Config) WithEndpointDiscovery(t bool) *Config {
c.EnableEndpointDiscovery = &t
return c
}
// WithDisableEndpointHostPrefix will set whether or not to use modeled host prefix
// when making requests.
func (c *Config) WithDisableEndpointHostPrefix(t bool) *Config {
c.DisableEndpointHostPrefix = &t
return c
}
// WithSTSRegionalEndpoint will set whether or not to use regional endpoint flag
// when resolving the endpoint for a service
func (c *Config) WithSTSRegionalEndpoint(sre endpoints.STSRegionalEndpoint) *Config {
c.STSRegionalEndpoint = sre
return c
}
// WithS3UsEast1RegionalEndpoint will set whether or not to use regional endpoint flag
// when resolving the endpoint for a service
func (c *Config) WithS3UsEast1RegionalEndpoint(sre endpoints.S3UsEast1RegionalEndpoint) *Config {
c.S3UsEast1RegionalEndpoint = sre
return c
}
// WithLowerCaseHeaderMaps sets a config LowerCaseHeaderMaps value
// returning a Config pointer for chaining.
func (c *Config) WithLowerCaseHeaderMaps(t bool) *Config {
c.LowerCaseHeaderMaps = &t
return c
}
// WithDisableRestProtocolURICleaning sets a config DisableRestProtocolURICleaning value
// returning a Config pointer for chaining.
func (c *Config) WithDisableRestProtocolURICleaning(t bool) *Config {
c.DisableRestProtocolURICleaning = &t
return c
}
// MergeIn merges the passed in configs into the existing config object. // MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) { func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs { for _, other := range cfgs {
@ -457,10 +560,18 @@ func mergeInConfig(dst *Config, other *Config) {
dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation
} }
if other.S3UseARNRegion != nil {
dst.S3UseARNRegion = other.S3UseARNRegion
}
if other.UseDualStack != nil { if other.UseDualStack != nil {
dst.UseDualStack = other.UseDualStack dst.UseDualStack = other.UseDualStack
} }
if other.UseDualStackEndpoint != endpoints.DualStackEndpointStateUnset {
dst.UseDualStackEndpoint = other.UseDualStackEndpoint
}
if other.EC2MetadataDisableTimeoutOverride != nil { if other.EC2MetadataDisableTimeoutOverride != nil {
dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride
} }
@ -476,6 +587,34 @@ func mergeInConfig(dst *Config, other *Config) {
if other.EnforceShouldRetryCheck != nil { if other.EnforceShouldRetryCheck != nil {
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
} }
if other.EnableEndpointDiscovery != nil {
dst.EnableEndpointDiscovery = other.EnableEndpointDiscovery
}
if other.DisableEndpointHostPrefix != nil {
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
}
if other.STSRegionalEndpoint != endpoints.UnsetSTSEndpoint {
dst.STSRegionalEndpoint = other.STSRegionalEndpoint
}
if other.S3UsEast1RegionalEndpoint != endpoints.UnsetS3UsEast1Endpoint {
dst.S3UsEast1RegionalEndpoint = other.S3UsEast1RegionalEndpoint
}
if other.LowerCaseHeaderMaps != nil {
dst.LowerCaseHeaderMaps = other.LowerCaseHeaderMaps
}
if other.UseDualStackEndpoint != endpoints.DualStackEndpointStateUnset {
dst.UseDualStackEndpoint = other.UseDualStackEndpoint
}
if other.UseFIPSEndpoint != endpoints.FIPSEndpointStateUnset {
dst.UseFIPSEndpoint = other.UseFIPSEndpoint
}
} }
// Copy will return a shallow copy of the Config object. If any additional // Copy will return a shallow copy of the Config object. If any additional

View file

@ -1,8 +1,9 @@
//go:build !go1.9
// +build !go1.9
package aws package aws
import ( import "time"
"time"
)
// Context is an copy of the Go v1.7 stdlib's context.Context interface. // Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext" // It is represented as a SDK interface to enable you to use the "WithContext"
@ -35,37 +36,3 @@ type Context interface {
// functions. // functions.
Value(key interface{}) interface{} Value(key interface{}) interface{}
} }
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

View file

@ -1,9 +0,0 @@
// +build go1.7
package aws
import "context"
var (
backgroundCtx = context.Background()
)

12
vendor/github.com/aws/aws-sdk-go/aws/context_1_9.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
//go:build go1.9
// +build go1.9
package aws
import "context"
// Context is an alias of the Go stdlib's context.Context interface.
// It can be used within the SDK's API operation "WithContext" methods.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context = context.Context

View file

@ -0,0 +1,23 @@
//go:build !go1.7
// +build !go1.7
package aws
import (
"github.com/aws/aws-sdk-go/internal/context"
)
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return context.BackgroundCtx
}

View file

@ -0,0 +1,21 @@
//go:build go1.7
// +build go1.7
package aws
import "context"
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return context.Background()
}

24
vendor/github.com/aws/aws-sdk-go/aws/context_sleep.go generated vendored Normal file
View file

@ -0,0 +1,24 @@
package aws
import (
"time"
)
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

View file

@ -179,6 +179,242 @@ func IntValueMap(src map[string]*int) map[string]int {
return dst return dst
} }
// Uint returns a pointer to the uint value passed in.
func Uint(v uint) *uint {
return &v
}
// UintValue returns the value of the uint pointer passed in or
// 0 if the pointer is nil.
func UintValue(v *uint) uint {
if v != nil {
return *v
}
return 0
}
// UintSlice converts a slice of uint values uinto a slice of
// uint pointers
func UintSlice(src []uint) []*uint {
dst := make([]*uint, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// UintValueSlice converts a slice of uint pointers uinto a slice of
// uint values
func UintValueSlice(src []*uint) []uint {
dst := make([]uint, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// UintMap converts a string map of uint values uinto a string
// map of uint pointers
func UintMap(src map[string]uint) map[string]*uint {
dst := make(map[string]*uint)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// UintValueMap converts a string map of uint pointers uinto a string
// map of uint values
func UintValueMap(src map[string]*uint) map[string]uint {
dst := make(map[string]uint)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int8 returns a pointer to the int8 value passed in.
func Int8(v int8) *int8 {
return &v
}
// Int8Value returns the value of the int8 pointer passed in or
// 0 if the pointer is nil.
func Int8Value(v *int8) int8 {
if v != nil {
return *v
}
return 0
}
// Int8Slice converts a slice of int8 values into a slice of
// int8 pointers
func Int8Slice(src []int8) []*int8 {
dst := make([]*int8, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int8ValueSlice converts a slice of int8 pointers into a slice of
// int8 values
func Int8ValueSlice(src []*int8) []int8 {
dst := make([]int8, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int8Map converts a string map of int8 values into a string
// map of int8 pointers
func Int8Map(src map[string]int8) map[string]*int8 {
dst := make(map[string]*int8)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int8ValueMap converts a string map of int8 pointers into a string
// map of int8 values
func Int8ValueMap(src map[string]*int8) map[string]int8 {
dst := make(map[string]int8)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int16 returns a pointer to the int16 value passed in.
func Int16(v int16) *int16 {
return &v
}
// Int16Value returns the value of the int16 pointer passed in or
// 0 if the pointer is nil.
func Int16Value(v *int16) int16 {
if v != nil {
return *v
}
return 0
}
// Int16Slice converts a slice of int16 values into a slice of
// int16 pointers
func Int16Slice(src []int16) []*int16 {
dst := make([]*int16, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int16ValueSlice converts a slice of int16 pointers into a slice of
// int16 values
func Int16ValueSlice(src []*int16) []int16 {
dst := make([]int16, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int16Map converts a string map of int16 values into a string
// map of int16 pointers
func Int16Map(src map[string]int16) map[string]*int16 {
dst := make(map[string]*int16)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int16ValueMap converts a string map of int16 pointers into a string
// map of int16 values
func Int16ValueMap(src map[string]*int16) map[string]int16 {
dst := make(map[string]int16)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int32 returns a pointer to the int32 value passed in.
func Int32(v int32) *int32 {
return &v
}
// Int32Value returns the value of the int32 pointer passed in or
// 0 if the pointer is nil.
func Int32Value(v *int32) int32 {
if v != nil {
return *v
}
return 0
}
// Int32Slice converts a slice of int32 values into a slice of
// int32 pointers
func Int32Slice(src []int32) []*int32 {
dst := make([]*int32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int32ValueSlice converts a slice of int32 pointers into a slice of
// int32 values
func Int32ValueSlice(src []*int32) []int32 {
dst := make([]int32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int32Map converts a string map of int32 values into a string
// map of int32 pointers
func Int32Map(src map[string]int32) map[string]*int32 {
dst := make(map[string]*int32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int32ValueMap converts a string map of int32 pointers into a string
// map of int32 values
func Int32ValueMap(src map[string]*int32) map[string]int32 {
dst := make(map[string]int32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int64 returns a pointer to the int64 value passed in. // Int64 returns a pointer to the int64 value passed in.
func Int64(v int64) *int64 { func Int64(v int64) *int64 {
return &v return &v
@ -238,6 +474,301 @@ func Int64ValueMap(src map[string]*int64) map[string]int64 {
return dst return dst
} }
// Uint8 returns a pointer to the uint8 value passed in.
func Uint8(v uint8) *uint8 {
return &v
}
// Uint8Value returns the value of the uint8 pointer passed in or
// 0 if the pointer is nil.
func Uint8Value(v *uint8) uint8 {
if v != nil {
return *v
}
return 0
}
// Uint8Slice converts a slice of uint8 values into a slice of
// uint8 pointers
func Uint8Slice(src []uint8) []*uint8 {
dst := make([]*uint8, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint8ValueSlice converts a slice of uint8 pointers into a slice of
// uint8 values
func Uint8ValueSlice(src []*uint8) []uint8 {
dst := make([]uint8, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint8Map converts a string map of uint8 values into a string
// map of uint8 pointers
func Uint8Map(src map[string]uint8) map[string]*uint8 {
dst := make(map[string]*uint8)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint8ValueMap converts a string map of uint8 pointers into a string
// map of uint8 values
func Uint8ValueMap(src map[string]*uint8) map[string]uint8 {
dst := make(map[string]uint8)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint16 returns a pointer to the uint16 value passed in.
func Uint16(v uint16) *uint16 {
return &v
}
// Uint16Value returns the value of the uint16 pointer passed in or
// 0 if the pointer is nil.
func Uint16Value(v *uint16) uint16 {
if v != nil {
return *v
}
return 0
}
// Uint16Slice converts a slice of uint16 values into a slice of
// uint16 pointers
func Uint16Slice(src []uint16) []*uint16 {
dst := make([]*uint16, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint16ValueSlice converts a slice of uint16 pointers into a slice of
// uint16 values
func Uint16ValueSlice(src []*uint16) []uint16 {
dst := make([]uint16, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint16Map converts a string map of uint16 values into a string
// map of uint16 pointers
func Uint16Map(src map[string]uint16) map[string]*uint16 {
dst := make(map[string]*uint16)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint16ValueMap converts a string map of uint16 pointers into a string
// map of uint16 values
func Uint16ValueMap(src map[string]*uint16) map[string]uint16 {
dst := make(map[string]uint16)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint32 returns a pointer to the uint32 value passed in.
func Uint32(v uint32) *uint32 {
return &v
}
// Uint32Value returns the value of the uint32 pointer passed in or
// 0 if the pointer is nil.
func Uint32Value(v *uint32) uint32 {
if v != nil {
return *v
}
return 0
}
// Uint32Slice converts a slice of uint32 values into a slice of
// uint32 pointers
func Uint32Slice(src []uint32) []*uint32 {
dst := make([]*uint32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint32ValueSlice converts a slice of uint32 pointers into a slice of
// uint32 values
func Uint32ValueSlice(src []*uint32) []uint32 {
dst := make([]uint32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint32Map converts a string map of uint32 values into a string
// map of uint32 pointers
func Uint32Map(src map[string]uint32) map[string]*uint32 {
dst := make(map[string]*uint32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint32ValueMap converts a string map of uint32 pointers into a string
// map of uint32 values
func Uint32ValueMap(src map[string]*uint32) map[string]uint32 {
dst := make(map[string]uint32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint64 returns a pointer to the uint64 value passed in.
func Uint64(v uint64) *uint64 {
return &v
}
// Uint64Value returns the value of the uint64 pointer passed in or
// 0 if the pointer is nil.
func Uint64Value(v *uint64) uint64 {
if v != nil {
return *v
}
return 0
}
// Uint64Slice converts a slice of uint64 values into a slice of
// uint64 pointers
func Uint64Slice(src []uint64) []*uint64 {
dst := make([]*uint64, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint64ValueSlice converts a slice of uint64 pointers into a slice of
// uint64 values
func Uint64ValueSlice(src []*uint64) []uint64 {
dst := make([]uint64, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint64Map converts a string map of uint64 values into a string
// map of uint64 pointers
func Uint64Map(src map[string]uint64) map[string]*uint64 {
dst := make(map[string]*uint64)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint64ValueMap converts a string map of uint64 pointers into a string
// map of uint64 values
func Uint64ValueMap(src map[string]*uint64) map[string]uint64 {
dst := make(map[string]uint64)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float32 returns a pointer to the float32 value passed in.
func Float32(v float32) *float32 {
return &v
}
// Float32Value returns the value of the float32 pointer passed in or
// 0 if the pointer is nil.
func Float32Value(v *float32) float32 {
if v != nil {
return *v
}
return 0
}
// Float32Slice converts a slice of float32 values into a slice of
// float32 pointers
func Float32Slice(src []float32) []*float32 {
dst := make([]*float32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Float32ValueSlice converts a slice of float32 pointers into a slice of
// float32 values
func Float32ValueSlice(src []*float32) []float32 {
dst := make([]float32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Float32Map converts a string map of float32 values into a string
// map of float32 pointers
func Float32Map(src map[string]float32) map[string]*float32 {
dst := make(map[string]*float32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Float32ValueMap converts a string map of float32 pointers into a string
// map of float32 values
func Float32ValueMap(src map[string]*float32) map[string]float32 {
dst := make(map[string]float32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float64 returns a pointer to the float64 value passed in. // Float64 returns a pointer to the float64 value passed in.
func Float64(v float64) *float64 { func Float64(v float64) *float64 {
return &v return &v

View file

@ -72,9 +72,9 @@ var ValidateReqSigHandler = request.NamedHandler{
signedTime = r.LastSignedAt signedTime = r.LastSignedAt
} }
// 10 minutes to allow for some clock skew/delays in transmission. // 5 minutes to allow for some clock skew/delays in transmission.
// Would be improved with aws/aws-sdk-go#423 // Would be improved with aws/aws-sdk-go#423
if signedTime.Add(10 * time.Minute).After(time.Now()) { if signedTime.Add(5 * time.Minute).After(time.Now()) {
return return
} }
@ -159,9 +159,9 @@ func handleSendError(r *request.Request, err error) {
Body: ioutil.NopCloser(bytes.NewReader([]byte{})), Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
} }
} }
// Catch all other request errors. // Catch all request errors, and let the default retrier determine
r.Error = awserr.New("RequestError", "send request failed", err) // if the error is retryable.
r.Retryable = aws.Bool(true) // network errors are retryable r.Error = awserr.New(request.ErrCodeRequestError, "send request failed", err)
// Override the error with a context canceled error, if that was canceled. // Override the error with a context canceled error, if that was canceled.
ctx := r.Context() ctx := r.Context()
@ -178,13 +178,15 @@ func handleSendError(r *request.Request, err error) {
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) { var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 { if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler // this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil) r.Error = awserr.New("UnknownError", "unknown error", r.Error)
} }
}} }}
// AfterRetryHandler performs final checks to determine if the request should // AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay. // be retried and how long to delay.
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) { var AfterRetryHandler = request.NamedHandler{
Name: "core.AfterRetryHandler",
Fn: func(r *request.Request) {
// If one of the other handlers already set the retry state // If one of the other handlers already set the retry state
// we don't want to override it based on the service's state // we don't want to override it based on the service's state
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) { if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
@ -214,7 +216,7 @@ var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn:
r.RetryCount++ r.RetryCount++
r.Error = nil r.Error = nil
} }
}} }}
// ValidateEndpointHandler is a request handler to validate a request had the // ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or // appropriate Region and Endpoint set. Will set r.Error if the endpoint or
@ -223,6 +225,8 @@ var ValidateEndpointHandler = request.NamedHandler{Name: "core.ValidateEndpointH
if r.ClientInfo.SigningRegion == "" && aws.StringValue(r.Config.Region) == "" { if r.ClientInfo.SigningRegion == "" && aws.StringValue(r.Config.Region) == "" {
r.Error = aws.ErrMissingRegion r.Error = aws.ErrMissingRegion
} else if r.ClientInfo.Endpoint == "" { } else if r.ClientInfo.Endpoint == "" {
// Was any endpoint provided by the user, or one was derived by the
// SDK's endpoint resolver?
r.Error = aws.ErrMissingEndpoint r.Error = aws.ErrMissingEndpoint
} }
}} }}

View file

@ -17,7 +17,7 @@ var SDKVersionUserAgentHandler = request.NamedHandler{
} }
const execEnvVar = `AWS_EXECUTION_ENV` const execEnvVar = `AWS_EXECUTION_ENV`
const execEnvUAKey = `exec_env` const execEnvUAKey = `exec-env`
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's // AddHostExecEnvUserAgentHander is a request handler appending the SDK's
// execution environment to the user agent. // execution environment to the user agent.

View file

@ -9,9 +9,7 @@ var (
// providers in the ChainProvider. // providers in the ChainProvider.
// //
// This has been deprecated. For verbose error messaging set // This has been deprecated. For verbose error messaging set
// aws.Config.CredentialsChainVerboseErrors to true // aws.Config.CredentialsChainVerboseErrors to true.
//
// @readonly
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders", ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
`no valid providers in chain. Deprecated. `no valid providers in chain. Deprecated.
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`, For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,

View file

@ -0,0 +1,23 @@
//go:build !go1.7
// +build !go1.7
package credentials
import (
"github.com/aws/aws-sdk-go/internal/context"
)
// backgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func backgroundContext() Context {
return context.BackgroundCtx
}

View file

@ -0,0 +1,21 @@
//go:build go1.7
// +build go1.7
package credentials
import "context"
// backgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func backgroundContext() Context {
return context.Background()
}

View file

@ -0,0 +1,40 @@
//go:build !go1.9
// +build !go1.9
package credentials
import "time"
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext"
// API methods with Go v1.6 and a Context type such as golang.org/x/net/context.
//
// This type, aws.Context, and context.Context are equivalent.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
Value(key interface{}) interface{}
}

View file

@ -0,0 +1,14 @@
//go:build go1.9
// +build go1.9
package credentials
import "context"
// Context is an alias of the Go stdlib's context.Context interface.
// It can be used within the SDK's API operation "WithContext" methods.
//
// This type, aws.Context, and context.Context are equivalent.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context = context.Context

View file

@ -49,8 +49,12 @@
package credentials package credentials
import ( import (
"fmt"
"sync" "sync"
"time" "time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/sync/singleflight"
) )
// AnonymousCredentials is an empty Credential object that can be used as // AnonymousCredentials is an empty Credential object that can be used as
@ -64,8 +68,6 @@ import (
// Credentials: credentials.AnonymousCredentials, // Credentials: credentials.AnonymousCredentials,
// }))) // })))
// // Access public S3 buckets. // // Access public S3 buckets.
//
// @readonly
var AnonymousCredentials = NewStaticCredentials("", "", "") var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields. // A Value is the AWS credentials value for individual credential fields.
@ -83,6 +85,12 @@ type Value struct {
ProviderName string ProviderName string
} }
// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}
// A Provider is the interface for any component which will provide credentials // A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what to // Value. A provider is required to manage its own Expired state, and what to
// be expired means. // be expired means.
@ -99,6 +107,21 @@ type Provider interface {
IsExpired() bool IsExpired() bool
} }
// ProviderWithContext is a Provider that can retrieve credentials with a Context
type ProviderWithContext interface {
Provider
RetrieveWithContext(Context) (Value, error)
}
// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
type Expirer interface {
// The time at which the credentials are no longer valid
ExpiresAt() time.Time
}
// An ErrorProvider is a stub credentials provider that always returns an error // An ErrorProvider is a stub credentials provider that always returns an error
// this is used by the SDK when construction a known provider is not possible // this is used by the SDK when construction a known provider is not possible
// due to an error. // due to an error.
@ -150,7 +173,9 @@ type Expiry struct {
// the expiration time given to ensure no requests are made with expired // the expiration time given to ensure no requests are made with expired
// tokens. // tokens.
func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) { func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) {
e.expiration = expiration // Passed in expirations should have the monotonic clock values stripped.
// This ensures time comparisons will be based on wall-time.
e.expiration = expiration.Round(0)
if window > 0 { if window > 0 {
e.expiration = e.expiration.Add(-window) e.expiration = e.expiration.Add(-window)
} }
@ -165,6 +190,11 @@ func (e *Expiry) IsExpired() bool {
return e.expiration.Before(curTime()) return e.expiration.Before(curTime())
} }
// ExpiresAt returns the expiration time of the credential
func (e *Expiry) ExpiresAt() time.Time {
return e.expiration
}
// A Credentials provides concurrency safe retrieval of AWS credentials Value. // A Credentials provides concurrency safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value // Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials. // expires the next Get will attempt to retrieve valid credentials.
@ -177,20 +207,82 @@ func (e *Expiry) IsExpired() bool {
// first instance of the credentials Value. All calls to Get() after that // first instance of the credentials Value. All calls to Get() after that
// will return the cached credentials Value until IsExpired() returns true. // will return the cached credentials Value until IsExpired() returns true.
type Credentials struct { type Credentials struct {
creds Value sf singleflight.Group
forceRefresh bool
m sync.RWMutex m sync.RWMutex
creds Value
provider Provider provider Provider
} }
// NewCredentials returns a pointer to a new Credentials with the provider set. // NewCredentials returns a pointer to a new Credentials with the provider set.
func NewCredentials(provider Provider) *Credentials { func NewCredentials(provider Provider) *Credentials {
return &Credentials{ c := &Credentials{
provider: provider, provider: provider,
forceRefresh: true,
} }
return c
}
// GetWithContext returns the credentials value, or error if the credentials
// Value failed to be retrieved. Will return early if the passed in context is
// canceled.
//
// Will return the cached credentials Value if it has not expired. If the
// credentials Value has expired the Provider's Retrieve() will be called
// to refresh the credentials.
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
//
// Passed in Context is equivalent to aws.Context, and context.Context.
func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
// Check if credentials are cached, and not expired.
select {
case curCreds, ok := <-c.asyncIsExpired():
// ok will only be true, of the credentials were not expired. ok will
// be false and have no value if the credentials are expired.
if ok {
return curCreds, nil
}
case <-ctx.Done():
return Value{}, awserr.New("RequestCanceled",
"request context canceled", ctx.Err())
}
// Cannot pass context down to the actual retrieve, because the first
// context would cancel the whole group when there is not direct
// association of items in the group.
resCh := c.sf.DoChan("", func() (interface{}, error) {
return c.singleRetrieve(&suppressedContext{ctx})
})
select {
case res := <-resCh:
return res.Val.(Value), res.Err
case <-ctx.Done():
return Value{}, awserr.New("RequestCanceled",
"request context canceled", ctx.Err())
}
}
func (c *Credentials) singleRetrieve(ctx Context) (interface{}, error) {
c.m.Lock()
defer c.m.Unlock()
if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
return curCreds, nil
}
var creds Value
var err error
if p, ok := c.provider.(ProviderWithContext); ok {
creds, err = p.RetrieveWithContext(ctx)
} else {
creds, err = c.provider.Retrieve()
}
if err == nil {
c.creds = creds
}
return creds, err
} }
// Get returns the credentials value, or error if the credentials Value failed // Get returns the credentials value, or error if the credentials Value failed
@ -203,30 +295,7 @@ func NewCredentials(provider Provider) *Credentials {
// If Credentials.Expire() was called the credentials Value will be force // If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed. // expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) Get() (Value, error) { func (c *Credentials) Get() (Value, error) {
// Check the cached credentials first with just the read lock. return c.GetWithContext(backgroundContext())
c.m.RLock()
if !c.isExpired() {
creds := c.creds
c.m.RUnlock()
return creds, nil
}
c.m.RUnlock()
// Credentials are expired need to retrieve the credentials taking the full
// lock.
c.m.Lock()
defer c.m.Unlock()
if c.isExpired() {
creds, err := c.provider.Retrieve()
if err != nil {
return Value{}, err
}
c.creds = creds
c.forceRefresh = false
}
return c.creds, nil
} }
// Expire expires the credentials and forces them to be retrieved on the // Expire expires the credentials and forces them to be retrieved on the
@ -238,7 +307,7 @@ func (c *Credentials) Expire() {
c.m.Lock() c.m.Lock()
defer c.m.Unlock() defer c.m.Unlock()
c.forceRefresh = true c.creds = Value{}
} }
// IsExpired returns if the credentials are no longer valid, and need // IsExpired returns if the credentials are no longer valid, and need
@ -250,10 +319,65 @@ func (c *Credentials) IsExpired() bool {
c.m.RLock() c.m.RLock()
defer c.m.RUnlock() defer c.m.RUnlock()
return c.isExpired() return c.isExpiredLocked(c.creds)
} }
// isExpired helper method wrapping the definition of expired credentials. // asyncIsExpired returns a channel of credentials Value. If the channel is
func (c *Credentials) isExpired() bool { // closed the credentials are expired and credentials value are not empty.
return c.forceRefresh || c.provider.IsExpired() func (c *Credentials) asyncIsExpired() <-chan Value {
ch := make(chan Value, 1)
go func() {
c.m.RLock()
defer c.m.RUnlock()
if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
ch <- curCreds
}
close(ch)
}()
return ch
}
// isExpiredLocked helper method wrapping the definition of expired credentials.
func (c *Credentials) isExpiredLocked(creds interface{}) bool {
return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
}
// ExpiresAt provides access to the functionality of the Expirer interface of
// the underlying Provider, if it supports that interface. Otherwise, it returns
// an error.
func (c *Credentials) ExpiresAt() (time.Time, error) {
c.m.RLock()
defer c.m.RUnlock()
expirer, ok := c.provider.(Expirer)
if !ok {
return time.Time{}, awserr.New("ProviderNotExpirer",
fmt.Sprintf("provider %s does not support ExpiresAt()",
c.creds.ProviderName),
nil)
}
if c.creds == (Value{}) {
// set expiration time to the distant past
return time.Time{}, nil
}
return expirer.ExpiresAt(), nil
}
type suppressedContext struct {
Context
}
func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
func (s *suppressedContext) Done() <-chan struct{} {
return nil
}
func (s *suppressedContext) Err() error {
return nil
} }

View file

@ -7,10 +7,12 @@ import (
"strings" "strings"
"time" "time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri" "github.com/aws/aws-sdk-go/internal/sdkuri"
) )
@ -86,7 +88,14 @@ func NewCredentialsWithClient(client *ec2metadata.EC2Metadata, options ...func(*
// Error will be returned if the request fails, or unable to extract // Error will be returned if the request fails, or unable to extract
// the desired credentials. // the desired credentials.
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) { func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
credsList, err := requestCredList(m.Client) return m.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
credsList, err := requestCredList(ctx, m.Client)
if err != nil { if err != nil {
return credentials.Value{ProviderName: ProviderName}, err return credentials.Value{ProviderName: ProviderName}, err
} }
@ -96,7 +105,7 @@ func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
} }
credsName := credsList[0] credsName := credsList[0]
roleCreds, err := requestCred(m.Client, credsName) roleCreds, err := requestCred(ctx, m.Client, credsName)
if err != nil { if err != nil {
return credentials.Value{ProviderName: ProviderName}, err return credentials.Value{ProviderName: ProviderName}, err
} }
@ -129,8 +138,8 @@ const iamSecurityCredsPath = "iam/security-credentials/"
// requestCredList requests a list of credentials from the EC2 service. // requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request // If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) { func requestCredList(ctx aws.Context, client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadata(iamSecurityCredsPath) resp, err := client.GetMetadataWithContext(ctx, iamSecurityCredsPath)
if err != nil { if err != nil {
return nil, awserr.New("EC2RoleRequestError", "no EC2 instance role found", err) return nil, awserr.New("EC2RoleRequestError", "no EC2 instance role found", err)
} }
@ -142,7 +151,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
} }
if err := s.Err(); err != nil { if err := s.Err(); err != nil {
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err) return nil, awserr.New(request.ErrCodeSerialization,
"failed to read EC2 instance role from metadata service", err)
} }
return credsList, nil return credsList, nil
@ -152,8 +162,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
// //
// If the credentials cannot be found, or there is an error reading the response // If the credentials cannot be found, or there is an error reading the response
// and error will be returned. // and error will be returned.
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) { func requestCred(ctx aws.Context, client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(sdkuri.PathJoin(iamSecurityCredsPath, credsName)) resp, err := client.GetMetadataWithContext(ctx, sdkuri.PathJoin(iamSecurityCredsPath, credsName))
if err != nil { if err != nil {
return ec2RoleCredRespBody{}, return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError", awserr.New("EC2RoleRequestError",
@ -164,7 +174,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred
respCreds := ec2RoleCredRespBody{} respCreds := ec2RoleCredRespBody{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{}, return ec2RoleCredRespBody{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName), fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
err) err)
} }

View file

@ -39,6 +39,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client/metadata" "github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
) )
// ProviderName is the name of the credentials provider. // ProviderName is the name of the credentials provider.
@ -65,6 +66,10 @@ type Provider struct {
// //
// If ExpiryWindow is 0 or less it will be ignored. // If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration ExpiryWindow time.Duration
// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
AuthorizationToken string
} }
// NewProviderClient returns a credentials Provider for retrieving AWS credentials // NewProviderClient returns a credentials Provider for retrieving AWS credentials
@ -93,8 +98,8 @@ func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint strin
return p return p
} }
// NewCredentialsClient returns a Credentials wrapper for retrieving credentials // NewCredentialsClient returns a pointer to a new Credentials object
// from an arbitrary endpoint concurrently. The client will request the // wrapping the endpoint credentials Provider.
func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials { func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...)) return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
} }
@ -111,7 +116,13 @@ func (p *Provider) IsExpired() bool {
// Retrieve will attempt to request the credentials from the endpoint the Provider // Retrieve will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails. // was configured for. And error will be returned if the retrieval fails.
func (p *Provider) Retrieve() (credentials.Value, error) { func (p *Provider) Retrieve() (credentials.Value, error) {
resp, err := p.getCredentials() return p.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
resp, err := p.getCredentials(ctx)
if err != nil { if err != nil {
return credentials.Value{ProviderName: ProviderName}, return credentials.Value{ProviderName: ProviderName},
awserr.New("CredentialsEndpointError", "failed to load credentials", err) awserr.New("CredentialsEndpointError", "failed to load credentials", err)
@ -143,7 +154,7 @@ type errorOutput struct {
Message string `json:"message"` Message string `json:"message"`
} }
func (p *Provider) getCredentials() (*getCredentialsOutput, error) { func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error) {
op := &request.Operation{ op := &request.Operation{
Name: "GetCredentials", Name: "GetCredentials",
HTTPMethod: "GET", HTTPMethod: "GET",
@ -151,7 +162,11 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
out := &getCredentialsOutput{} out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out) req := p.Client.NewRequest(op, nil, out)
req.SetContext(ctx)
req.HTTPRequest.Header.Set("Accept", "application/json") req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
}
return out, req.Send() return out, req.Send()
} }
@ -167,7 +182,7 @@ func unmarshalHandler(r *request.Request) {
out := r.Data.(*getCredentialsOutput) out := r.Data.(*getCredentialsOutput)
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil { if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
r.Error = awserr.New("SerializationError", r.Error = awserr.New(request.ErrCodeSerialization,
"failed to decode endpoint credentials", "failed to decode endpoint credentials",
err, err,
) )
@ -178,11 +193,15 @@ func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
var errOut errorOutput var errOut errorOutput
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil { err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
r.Error = awserr.New("SerializationError", if err != nil {
"failed to decode endpoint credentials", r.Error = awserr.NewRequestFailure(
err, awserr.New(request.ErrCodeSerialization,
"failed to decode error message", err),
r.HTTPResponse.StatusCode,
r.RequestID,
) )
return
} }
// Response body format is not consistent between metadata endpoints. // Response body format is not consistent between metadata endpoints.

View file

@ -12,14 +12,10 @@ const EnvProviderName = "EnvProvider"
var ( var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be // ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment. // found in the process's environment.
//
// @readonly
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil) ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key // ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment. // can't be found in the process's environment.
//
// @readonly
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil) ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
) )

View file

@ -0,0 +1,426 @@
/*
Package processcreds is a credential Provider to retrieve `credential_process`
credentials.
WARNING: The following describes a method of sourcing credentials from an external
process. This can potentially be dangerous, so proceed with caution. Other
credential providers should be preferred if at all possible. If using this
option, you should make sure that the config file is as locked down as possible
using security best practices for your operating system.
You can use credentials from a `credential_process` in a variety of ways.
One way is to setup your shared config file, located in the default
location, with the `credential_process` key and the command you want to be
called. You also need to set the AWS_SDK_LOAD_CONFIG environment variable
(e.g., `export AWS_SDK_LOAD_CONFIG=1`) to use the shared config file.
[default]
credential_process = /command/to/call
Creating a new session will use the credential process to retrieve credentials.
NOTE: If there are credentials in the profile you are using, the credential
process will not be used.
// Initialize a session to load credentials.
sess, _ := session.NewSession(&aws.Config{
Region: aws.String("us-east-1")},
)
// Create S3 service client to use the credentials.
svc := s3.New(sess)
Another way to use the `credential_process` method is by using
`credentials.NewCredentials()` and providing a command to be executed to
retrieve credentials:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentials("/path/to/command")
// Create service client value configured for credentials.
svc := s3.New(sess, &aws.Config{Credentials: creds})
You can set a non-default timeout for the `credential_process` with another
constructor, `credentials.NewCredentialsTimeout()`, providing the timeout. To
set a one minute timeout:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentialsTimeout(
"/path/to/command",
time.Duration(500) * time.Millisecond)
If you need more control, you can set any configurable options in the
credentials using one or more option functions. For example, you can set a two
minute timeout, a credential duration of 60 minutes, and a maximum stdout
buffer size of 2k.
creds := processcreds.NewCredentials(
"/path/to/command",
func(opt *ProcessProvider) {
opt.Timeout = time.Duration(2) * time.Minute
opt.Duration = time.Duration(60) * time.Minute
opt.MaxBufSize = 2048
})
You can also use your own `exec.Cmd`:
// Create an exec.Cmd
myCommand := exec.Command("/path/to/command")
// Create credentials using your exec.Cmd and custom timeout
creds := processcreds.NewCredentialsCommand(
myCommand,
func(opt *processcreds.ProcessProvider) {
opt.Timeout = time.Duration(1) * time.Second
})
*/
package processcreds
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"runtime"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
const (
// ProviderName is the name this credentials provider will label any
// returned credentials Value with.
ProviderName = `ProcessProvider`
// ErrCodeProcessProviderParse error parsing process output
ErrCodeProcessProviderParse = "ProcessProviderParseError"
// ErrCodeProcessProviderVersion version error in output
ErrCodeProcessProviderVersion = "ProcessProviderVersionError"
// ErrCodeProcessProviderRequired required attribute missing in output
ErrCodeProcessProviderRequired = "ProcessProviderRequiredError"
// ErrCodeProcessProviderExecution execution of command failed
ErrCodeProcessProviderExecution = "ProcessProviderExecutionError"
// errMsgProcessProviderTimeout process took longer than allowed
errMsgProcessProviderTimeout = "credential process timed out"
// errMsgProcessProviderProcess process error
errMsgProcessProviderProcess = "error in credential_process"
// errMsgProcessProviderParse problem parsing output
errMsgProcessProviderParse = "parse failed of credential_process output"
// errMsgProcessProviderVersion version error in output
errMsgProcessProviderVersion = "wrong version in process output (not 1)"
// errMsgProcessProviderMissKey missing access key id in output
errMsgProcessProviderMissKey = "missing AccessKeyId in process output"
// errMsgProcessProviderMissSecret missing secret acess key in output
errMsgProcessProviderMissSecret = "missing SecretAccessKey in process output"
// errMsgProcessProviderPrepareCmd prepare of command failed
errMsgProcessProviderPrepareCmd = "failed to prepare command"
// errMsgProcessProviderEmptyCmd command must not be empty
errMsgProcessProviderEmptyCmd = "command must not be empty"
// errMsgProcessProviderPipe failed to initialize pipe
errMsgProcessProviderPipe = "failed to initialize pipe"
// DefaultDuration is the default amount of time in minutes that the
// credentials will be valid for.
DefaultDuration = time.Duration(15) * time.Minute
// DefaultBufSize limits buffer size from growing to an enormous
// amount due to a faulty process.
DefaultBufSize = int(8 * sdkio.KibiByte)
// DefaultTimeout default limit on time a process can run.
DefaultTimeout = time.Duration(1) * time.Minute
)
// ProcessProvider satisfies the credentials.Provider interface, and is a
// client to retrieve credentials from a process.
type ProcessProvider struct {
staticCreds bool
credentials.Expiry
originalCommand []string
// Expiry duration of the credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// A string representing an os command that should return a JSON with
// credential information.
command *exec.Cmd
// MaxBufSize limits memory usage from growing to an enormous
// amount due to a faulty process.
MaxBufSize int
// Timeout limits the time a process can run.
Timeout time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// ProcessProvider. The credentials will expire every 15 minutes by default.
func NewCredentials(command string, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: exec.Command(command),
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsTimeout returns a pointer to a new Credentials object with
// the specified command and timeout, and default duration and max buffer size.
func NewCredentialsTimeout(command string, timeout time.Duration) *credentials.Credentials {
p := NewCredentials(command, func(opt *ProcessProvider) {
opt.Timeout = timeout
})
return p
}
// NewCredentialsCommand returns a pointer to a new Credentials object with
// the specified command, and default timeout, duration and max buffer size.
func NewCredentialsCommand(command *exec.Cmd, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: command,
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
type credentialProcessResponse struct {
Version int
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
SessionToken string
Expiration *time.Time
}
// Retrieve executes the 'credential_process' and returns the credentials.
func (p *ProcessProvider) Retrieve() (credentials.Value, error) {
out, err := p.executeCredentialProcess()
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
// Serialize and validate response
resp := &credentialProcessResponse{}
if err = json.Unmarshal(out, resp); err != nil {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderParse,
fmt.Sprintf("%s: %s", errMsgProcessProviderParse, string(out)),
err)
}
if resp.Version != 1 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderVersion,
errMsgProcessProviderVersion,
nil)
}
if len(resp.AccessKeyID) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissKey,
nil)
}
if len(resp.SecretAccessKey) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissSecret,
nil)
}
// Handle expiration
p.staticCreds = resp.Expiration == nil
if resp.Expiration != nil {
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
}
return credentials.Value{
ProviderName: ProviderName,
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.SessionToken,
}, nil
}
// IsExpired returns true if the credentials retrieved are expired, or not yet
// retrieved.
func (p *ProcessProvider) IsExpired() bool {
if p.staticCreds {
return false
}
return p.Expiry.IsExpired()
}
// prepareCommand prepares the command to be executed.
func (p *ProcessProvider) prepareCommand() error {
var cmdArgs []string
if runtime.GOOS == "windows" {
cmdArgs = []string{"cmd.exe", "/C"}
} else {
cmdArgs = []string{"sh", "-c"}
}
if len(p.originalCommand) == 0 {
p.originalCommand = make([]string, len(p.command.Args))
copy(p.originalCommand, p.command.Args)
// check for empty command because it succeeds
if len(strings.TrimSpace(p.originalCommand[0])) < 1 {
return awserr.New(
ErrCodeProcessProviderExecution,
fmt.Sprintf(
"%s: %s",
errMsgProcessProviderPrepareCmd,
errMsgProcessProviderEmptyCmd),
nil)
}
}
cmdArgs = append(cmdArgs, p.originalCommand...)
p.command = exec.Command(cmdArgs[0], cmdArgs[1:]...)
p.command.Env = os.Environ()
return nil
}
// executeCredentialProcess starts the credential process on the OS and
// returns the results or an error.
func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
if err := p.prepareCommand(); err != nil {
return nil, err
}
// Setup the pipes
outReadPipe, outWritePipe, err := os.Pipe()
if err != nil {
return nil, awserr.New(
ErrCodeProcessProviderExecution,
errMsgProcessProviderPipe,
err)
}
p.command.Stderr = os.Stderr // display stderr on console for MFA
p.command.Stdout = outWritePipe // get creds json on process's stdout
p.command.Stdin = os.Stdin // enable stdin for MFA
output := bytes.NewBuffer(make([]byte, 0, p.MaxBufSize))
stdoutCh := make(chan error, 1)
go readInput(
io.LimitReader(outReadPipe, int64(p.MaxBufSize)),
output,
stdoutCh)
execCh := make(chan error, 1)
go executeCommand(*p.command, execCh)
finished := false
var errors []error
for !finished {
select {
case readError := <-stdoutCh:
errors = appendError(errors, readError)
finished = true
case execError := <-execCh:
err := outWritePipe.Close()
errors = appendError(errors, err)
errors = appendError(errors, execError)
if errors != nil {
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderProcess,
errors)
}
case <-time.After(p.Timeout):
finished = true
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderTimeout,
errors) // errors can be nil
}
}
out := output.Bytes()
if runtime.GOOS == "windows" {
// windows adds slashes to quotes
out = []byte(strings.Replace(string(out), `\"`, `"`, -1))
}
return out, nil
}
// appendError conveniently checks for nil before appending slice
func appendError(errors []error, err error) []error {
if err != nil {
return append(errors, err)
}
return errors
}
func executeCommand(cmd exec.Cmd, exec chan error) {
// Start the command
err := cmd.Start()
if err == nil {
err = cmd.Wait()
}
exec <- err
}
func readInput(r io.Reader, w io.Writer, read chan error) {
tee := io.TeeReader(r, w)
_, err := ioutil.ReadAll(tee)
if err == io.EOF {
err = nil
}
read <- err // will only arrive here when write end of pipe is closed
}

View file

@ -4,9 +4,8 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/go-ini/ini"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/ini"
"github.com/aws/aws-sdk-go/internal/shareddefaults" "github.com/aws/aws-sdk-go/internal/shareddefaults"
) )
@ -18,8 +17,9 @@ var (
ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil) ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil)
) )
// A SharedCredentialsProvider retrieves credentials from the current user's home // A SharedCredentialsProvider retrieves access key pair (access key ID,
// directory, and keeps track if those credentials are expired. // secret access key, and session token if present) credentials from the current
// user's home directory, and keeps track if those credentials are expired.
// //
// Profile ini file example: $HOME/.aws/credentials // Profile ini file example: $HOME/.aws/credentials
type SharedCredentialsProvider struct { type SharedCredentialsProvider struct {
@ -77,36 +77,37 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
// The credentials retrieved from the profile will be returned or error. Error will be // The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid. // returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) { func loadProfile(filename, profile string) (Value, error) {
config, err := ini.Load(filename) config, err := ini.OpenFile(filename)
if err != nil { if err != nil {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err) return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
} }
iniProfile, err := config.GetSection(profile)
if err != nil { iniProfile, ok := config.GetSection(profile)
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", err) if !ok {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", nil)
} }
id, err := iniProfile.GetKey("aws_access_key_id") id := iniProfile.String("aws_access_key_id")
if err != nil { if len(id) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey", return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename), fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
err) nil)
} }
secret, err := iniProfile.GetKey("aws_secret_access_key") secret := iniProfile.String("aws_secret_access_key")
if err != nil { if len(secret) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret", return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename), fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil) nil)
} }
// Default to empty string if not found // Default to empty string if not found
token := iniProfile.Key("aws_session_token") token := iniProfile.String("aws_session_token")
return Value{ return Value{
AccessKeyID: id.String(), AccessKeyID: id,
SecretAccessKey: secret.String(), SecretAccessKey: secret,
SessionToken: token.String(), SessionToken: token,
ProviderName: SharedCredsProviderName, ProviderName: SharedCredsProviderName,
}, nil }, nil
} }

View file

@ -0,0 +1,60 @@
// Package ssocreds provides a credential provider for retrieving temporary AWS credentials using an SSO access token.
//
// IMPORTANT: The provider in this package does not initiate or perform the AWS SSO login flow. The SDK provider
// expects that you have already performed the SSO login flow using AWS CLI using the "aws sso login" command, or by
// some other mechanism. The provider must find a valid non-expired access token for the AWS SSO user portal URL in
// ~/.aws/sso/cache. If a cached token is not found, it is expired, or the file is malformed an error will be returned.
//
// Loading AWS SSO credentials with the AWS shared configuration file
//
// You can use configure AWS SSO credentials from the AWS shared configuration file by
// providing the specifying the required keys in the profile:
//
// sso_account_id
// sso_region
// sso_role_name
// sso_start_url
//
// For example, the following defines a profile "devsso" and specifies the AWS SSO parameters that defines the target
// account, role, sign-on portal, and the region where the user portal is located. Note: all SSO arguments must be
// provided, or an error will be returned.
//
// [profile devsso]
// sso_start_url = https://my-sso-portal.awsapps.com/start
// sso_role_name = SSOReadOnlyRole
// sso_region = us-east-1
// sso_account_id = 123456789012
//
// Using the config module, you can load the AWS SDK shared configuration, and specify that this profile be used to
// retrieve credentials. For example:
//
// sess, err := session.NewSessionWithOptions(session.Options{
// SharedConfigState: session.SharedConfigEnable,
// Profile: "devsso",
// })
// if err != nil {
// return err
// }
//
// Programmatically loading AWS SSO credentials directly
//
// You can programmatically construct the AWS SSO Provider in your application, and provide the necessary information
// to load and retrieve temporary credentials using an access token from ~/.aws/sso/cache.
//
// svc := sso.New(sess, &aws.Config{
// Region: aws.String("us-west-2"), // Client Region must correspond to the AWS SSO user portal region
// })
//
// provider := ssocreds.NewCredentialsWithClient(svc, "123456789012", "SSOReadOnlyRole", "https://my-sso-portal.awsapps.com/start")
//
// credentials, err := provider.Get()
// if err != nil {
// return err
// }
//
// Additional Resources
//
// Configuring the AWS CLI to use AWS Single Sign-On: https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html
//
// AWS Single Sign-On User Guide: https://docs.aws.amazon.com/singlesignon/latest/userguide/what-is.html
package ssocreds

View file

@ -0,0 +1,10 @@
//go:build !windows
// +build !windows
package ssocreds
import "os"
func getHomeDirectory() string {
return os.Getenv("HOME")
}

View file

@ -0,0 +1,7 @@
package ssocreds
import "os"
func getHomeDirectory() string {
return os.Getenv("USERPROFILE")
}

View file

@ -0,0 +1,180 @@
package ssocreds
import (
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sso"
"github.com/aws/aws-sdk-go/service/sso/ssoiface"
)
// ErrCodeSSOProviderInvalidToken is the code type that is returned if loaded token has expired or is otherwise invalid.
// To refresh the SSO session run aws sso login with the corresponding profile.
const ErrCodeSSOProviderInvalidToken = "SSOProviderInvalidToken"
const invalidTokenMessage = "the SSO session has expired or is invalid"
func init() {
nowTime = time.Now
defaultCacheLocation = defaultCacheLocationImpl
}
var nowTime func() time.Time
// ProviderName is the name of the provider used to specify the source of credentials.
const ProviderName = "SSOProvider"
var defaultCacheLocation func() string
func defaultCacheLocationImpl() string {
return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache")
}
// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token.
type Provider struct {
credentials.Expiry
// The Client which is configured for the AWS Region where the AWS SSO user portal is located.
Client ssoiface.SSOAPI
// The AWS account that is assigned to the user.
AccountID string
// The role name that is assigned to the user.
RoleName string
// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
StartURL string
}
// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured
// for the AWS Region where the AWS SSO user portal is located.
func NewCredentials(configProvider client.ConfigProvider, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
return NewCredentialsWithClient(sso.New(configProvider), accountID, roleName, startURL, optFns...)
}
// NewCredentialsWithClient returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured
// for the AWS Region where the AWS SSO user portal is located.
func NewCredentialsWithClient(client ssoiface.SSOAPI, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
p := &Provider{
Client: client,
AccountID: accountID,
RoleName: roleName,
StartURL: startURL,
}
for _, fn := range optFns {
fn(p)
}
return credentials.NewCredentials(p)
}
// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
tokenFile, err := loadTokenFile(p.StartURL)
if err != nil {
return credentials.Value{}, err
}
output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{
AccessToken: &tokenFile.AccessToken,
AccountId: &p.AccountID,
RoleName: &p.RoleName,
})
if err != nil {
return credentials.Value{}, err
}
expireTime := time.Unix(0, aws.Int64Value(output.RoleCredentials.Expiration)*int64(time.Millisecond)).UTC()
p.SetExpiration(expireTime, 0)
return credentials.Value{
AccessKeyID: aws.StringValue(output.RoleCredentials.AccessKeyId),
SecretAccessKey: aws.StringValue(output.RoleCredentials.SecretAccessKey),
SessionToken: aws.StringValue(output.RoleCredentials.SessionToken),
ProviderName: ProviderName,
}, nil
}
func getCacheFileName(url string) (string, error) {
hash := sha1.New()
_, err := hash.Write([]byte(url))
if err != nil {
return "", err
}
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
}
type rfc3339 time.Time
func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
var value string
if err := json.Unmarshal(bytes, &value); err != nil {
return err
}
parse, err := time.Parse(time.RFC3339, value)
if err != nil {
return fmt.Errorf("expected RFC3339 timestamp: %v", err)
}
*r = rfc3339(parse)
return nil
}
type token struct {
AccessToken string `json:"accessToken"`
ExpiresAt rfc3339 `json:"expiresAt"`
Region string `json:"region,omitempty"`
StartURL string `json:"startUrl,omitempty"`
}
func (t token) Expired() bool {
return nowTime().Round(0).After(time.Time(t.ExpiresAt))
}
func loadTokenFile(startURL string) (t token, err error) {
key, err := getCacheFileName(startURL)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
if err := json.Unmarshal(fileBytes, &t); err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
if len(t.AccessToken) == 0 {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
}
if t.Expired() {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
}
return t, nil
}

View file

@ -9,8 +9,6 @@ const StaticProviderName = "StaticProvider"
var ( var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty. // ErrStaticCredentialsEmpty is emitted when static credentials are empty.
//
// @readonly
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil) ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
) )
@ -21,7 +19,9 @@ type StaticProvider struct {
} }
// NewStaticCredentials returns a pointer to a new Credentials object // NewStaticCredentials returns a pointer to a new Credentials object
// wrapping a static credentials value provider. // wrapping a static credentials value provider. Token is only required
// for temporary security credentials retrieved via STS, otherwise an empty
// string can be passed for this parameter.
func NewStaticCredentials(id, secret, token string) *Credentials { func NewStaticCredentials(id, secret, token string) *Credentials {
return NewCredentials(&StaticProvider{Value: Value{ return NewCredentials(&StaticProvider{Value: Value{
AccessKeyID: id, AccessKeyID: id,

View file

@ -80,19 +80,22 @@ package stscreds
import ( import (
"fmt" "fmt"
"os"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts"
) )
// StdinTokenProvider will prompt on stdout and read from stdin for a string value. // StdinTokenProvider will prompt on stderr and read from stdin for a string value.
// An error is returned if reading from stdin fails. // An error is returned if reading from stdin fails.
// //
// Use this function go read MFA tokens from stdin. The function makes no attempt // Use this function to read MFA tokens from stdin. The function makes no attempt
// to make atomic prompts from stdin across multiple gorouties. // to make atomic prompts from stdin across multiple gorouties.
// //
// Using StdinTokenProvider with multiple AssumeRoleProviders, or Credentials will // Using StdinTokenProvider with multiple AssumeRoleProviders, or Credentials will
@ -102,7 +105,7 @@ import (
// Will wait forever until something is provided on the stdin. // Will wait forever until something is provided on the stdin.
func StdinTokenProvider() (string, error) { func StdinTokenProvider() (string, error) {
var v string var v string
fmt.Printf("Assume Role MFA token code: ") fmt.Fprintf(os.Stderr, "Assume Role MFA token code: ")
_, err := fmt.Scanln(&v) _, err := fmt.Scanln(&v)
return v, err return v, err
@ -116,6 +119,10 @@ type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
} }
type assumeRolerWithContext interface {
AssumeRoleWithContext(aws.Context, *sts.AssumeRoleInput, ...request.Option) (*sts.AssumeRoleOutput, error)
}
// DefaultDuration is the default amount of time in minutes that the credentials // DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for. // will be valid for.
var DefaultDuration = time.Duration(15) * time.Minute var DefaultDuration = time.Duration(15) * time.Minute
@ -142,6 +149,13 @@ type AssumeRoleProvider struct {
// Session name, if you wish to reuse the credentials elsewhere. // Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string RoleSessionName string
// Optional, you can pass tag key-value pairs to your session. These tags are called session tags.
Tags []*sts.Tag
// A list of keys for session tags that you want to set as transitive.
// If you set a tag key as transitive, the corresponding key and value passes to subsequent sessions in a role chain.
TransitiveTagKeys []*string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set. // Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration Duration time.Duration
@ -155,6 +169,29 @@ type AssumeRoleProvider struct {
// size. // size.
Policy *string Policy *string
// The ARNs of IAM managed policies you want to use as managed session policies.
// The policies must exist in the same account as the role.
//
// This parameter is optional. You can provide up to 10 managed policy ARNs.
// However, the plain text that you use for both inline and managed session
// policies can't exceed 2,048 characters.
//
// An AWS conversion compresses the passed session policies and session tags
// into a packed binary format that has a separate limit. Your request can fail
// for this limit even if your plain text meets the other requirements. The
// PackedPolicySize response element indicates by percentage how close the policies
// and tags for your request are to the upper size limit.
//
// Passing policies to this operation returns new temporary credentials. The
// resulting session's permissions are the intersection of the role's identity-based
// policy and the session policies. You can use the role's temporary credentials
// in subsequent AWS API calls to access resources in the account that owns
// the role. You cannot use session policies to grant more permissions than
// those allowed by the identity-based policy of the role that is being assumed.
// For more information, see Session Policies (https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html#policies_session)
// in the IAM User Guide.
PolicyArns []*sts.PolicyDescriptorType
// The identification number of the MFA device that is associated with the user // The identification number of the MFA device that is associated with the user
// who is making the AssumeRole call. Specify this value if the trust policy // who is making the AssumeRole call. Specify this value if the trust policy
// of the role being assumed includes a condition that requires MFA authentication. // of the role being assumed includes a condition that requires MFA authentication.
@ -193,11 +230,25 @@ type AssumeRoleProvider struct {
// //
// If ExpiryWindow is 0 or less it will be ignored. // If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration ExpiryWindow time.Duration
// MaxJitterFrac reduces the effective Duration of each credential requested
// by a random percentage between 0 and MaxJitterFraction. MaxJitterFrac must
// have a value between 0 and 1. Any other value may lead to expected behavior.
// With a MaxJitterFrac value of 0, default) will no jitter will be used.
//
// For example, with a Duration of 30m and a MaxJitterFrac of 0.1, the
// AssumeRole call will be made with an arbitrary Duration between 27m and
// 30m.
//
// MaxJitterFrac should not be negative.
MaxJitterFrac float64
} }
// NewCredentials returns a pointer to a new Credentials object wrapping the // NewCredentials returns a pointer to a new Credentials value wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the // AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation. // role will be named after a nanosecond timestamp of this operation. The
// Credentials value will attempt to refresh the credentials using the provider
// when Credentials.Get is called, if the cached credentials are expiring.
// //
// Takes a Config provider to create the STS client. The ConfigProvider is // Takes a Config provider to create the STS client. The ConfigProvider is
// satisfied by the session.Session type. // satisfied by the session.Session type.
@ -219,9 +270,11 @@ func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*As
return credentials.NewCredentials(p) return credentials.NewCredentials(p)
} }
// NewCredentialsWithClient returns a pointer to a new Credentials object wrapping the // NewCredentialsWithClient returns a pointer to a new Credentials value wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the // AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation. // role will be named after a nanosecond timestamp of this operation. The
// Credentials value will attempt to refresh the credentials using the provider
// when Credentials.Get is called, if the cached credentials are expiring.
// //
// Takes an AssumeRoler which can be satisfied by the STS client. // Takes an AssumeRoler which can be satisfied by the STS client.
// //
@ -244,7 +297,11 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*
// Retrieve generates a new set of temporary credentials using STS. // Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
// Apply defaults where parameters are not set. // Apply defaults where parameters are not set.
if p.RoleSessionName == "" { if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique. // Try to work out a role name that will hopefully end up unique.
@ -254,11 +311,15 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Expire as often as AWS permits. // Expire as often as AWS permits.
p.Duration = DefaultDuration p.Duration = DefaultDuration
} }
jitter := time.Duration(sdkrand.SeededRand.Float64() * p.MaxJitterFrac * float64(p.Duration))
input := &sts.AssumeRoleInput{ input := &sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)), DurationSeconds: aws.Int64(int64((p.Duration - jitter) / time.Second)),
RoleArn: aws.String(p.RoleARN), RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName), RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID, ExternalId: p.ExternalID,
Tags: p.Tags,
PolicyArns: p.PolicyArns,
TransitiveTagKeys: p.TransitiveTagKeys,
} }
if p.Policy != nil { if p.Policy != nil {
input.Policy = p.Policy input.Policy = p.Policy
@ -281,7 +342,15 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
} }
} }
roleOutput, err := p.Client.AssumeRole(input) var roleOutput *sts.AssumeRoleOutput
var err error
if c, ok := p.Client.(assumeRolerWithContext); ok {
roleOutput, err = c.AssumeRoleWithContext(ctx, input)
} else {
roleOutput, err = p.Client.AssumeRole(input)
}
if err != nil { if err != nil {
return credentials.Value{ProviderName: ProviderName}, err return credentials.Value{ProviderName: ProviderName}, err
} }

View file

@ -0,0 +1,182 @@
package stscreds
import (
"fmt"
"io/ioutil"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)
const (
// ErrCodeWebIdentity will be used as an error code when constructing
// a new error to be returned during session creation or retrieval.
ErrCodeWebIdentity = "WebIdentityErr"
// WebIdentityProviderName is the web identity provider name
WebIdentityProviderName = "WebIdentityCredentials"
)
// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = time.Now
// TokenFetcher should return WebIdentity token bytes or an error
type TokenFetcher interface {
FetchToken(credentials.Context) ([]byte, error)
}
// FetchTokenPath is a path to a WebIdentity token file
type FetchTokenPath string
// FetchToken returns a token by reading from the filesystem
func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) {
data, err := ioutil.ReadFile(string(f))
if err != nil {
errMsg := fmt.Sprintf("unable to read file at %s", f)
return nil, awserr.New(ErrCodeWebIdentity, errMsg, err)
}
return data, nil
}
// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry
// The policy ARNs to use with the web identity assumed role.
PolicyArns []*sts.PolicyDescriptorType
// Duration the STS credentials will be valid for. Truncated to seconds.
// If unset, the assumed role will use AssumeRoleWithWebIdentity's default
// expiry duration. See
// https://docs.aws.amazon.com/sdk-for-go/api/service/sts/#STS.AssumeRoleWithWebIdentity
// for more information.
Duration time.Duration
// The amount of time the credentials will be refreshed before they expire.
// This is useful refresh credentials before they expire to reduce risk of
// using credentials as they expire. If unset, will default to no expiry
// window.
ExpiryWindow time.Duration
client stsiface.STSAPI
tokenFetcher TokenFetcher
roleARN string
roleSessionName string
}
// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options, and wrap with credentials.NewCredentials helper.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
return credentials.NewCredentials(p)
}
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options.
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, FetchTokenPath(path))
}
// NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI and a TokenFetcher
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options.
func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider {
return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, tokenFetcher)
}
// NewWebIdentityRoleProviderWithOptions will return an initialize
// WebIdentityRoleProvider with the provided stsiface.STSAPI, role ARN, and a
// TokenFetcher. Additional options can be provided as functional options.
//
// TokenFetcher is the implementation that will retrieve the JWT token from to
// assume the role with. Use the provided FetchTokenPath implementation to
// retrieve the JWT token using a file system path.
func NewWebIdentityRoleProviderWithOptions(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher, optFns ...func(*WebIdentityRoleProvider)) *WebIdentityRoleProvider {
p := WebIdentityRoleProvider{
client: svc,
tokenFetcher: tokenFetcher,
roleARN: roleARN,
roleSessionName: roleSessionName,
}
for _, fn := range optFns {
fn(&p)
}
return &p
}
// Retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext will attempt to assume a role from a token which is
// located at 'WebIdentityTokenFilePath' specified destination and if that is
// empty an error will be returned.
func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
b, err := p.tokenFetcher.FetchToken(ctx)
if err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed fetching WebIdentity token: ", err)
}
sessionName := p.roleSessionName
if len(sessionName) == 0 {
// session name is used to uniquely identify a session. This simply
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}
var duration *int64
if p.Duration != 0 {
duration = aws.Int64(int64(p.Duration / time.Second))
}
req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
PolicyArns: p.PolicyArns,
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
DurationSeconds: duration,
})
req.SetContext(ctx)
// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.
req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
if err := req.Send(); err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
}
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
value := credentials.Value{
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
ProviderName: WebIdentityProviderName,
}
return value, nil
}

View file

@ -1,30 +1,61 @@
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics // Package csm provides the Client Side Monitoring (CSM) client which enables
// via UDP connection. Using the Start function will enable the reporting of // sending metrics via UDP connection to the CSM agent. This package provides
// metrics on a given port. If Start is called, with different parameters, again, // control options, and configuration for the CSM client. The client can be
// a panic will occur. // controlled manually, or automatically via the SDK's Session configuration.
// //
// Pause can be called to pause any metrics publishing on a given port. Sessions // Enabling CSM client via SDK's Session configuration
// that have had their handlers modified via InjectHandlers may still be used. //
// However, the handlers will act as a no-op meaning no metrics will be published. // The CSM client can be enabled automatically via SDK's Session configuration.
// The SDK's session configuration enables the CSM client if the AWS_CSM_PORT
// environment variable is set to a non-empty value.
//
// The configuration options for the CSM client via the SDK's session
// configuration are:
//
// * AWS_CSM_PORT=<port number>
// The port number the CSM agent will receive metrics on.
//
// * AWS_CSM_HOST=<hostname or ip>
// The hostname, or IP address the CSM agent will receive metrics on.
// Without port number.
//
// Manually enabling the CSM client
//
// The CSM client can be started, paused, and resumed manually. The Start
// function will enable the CSM client to publish metrics to the CSM agent. It
// is safe to call Start concurrently, but if Start is called additional times
// with different ClientID or address it will panic.
// //
// Example:
// r, err := csm.Start("clientID", ":31000") // r, err := csm.Start("clientID", ":31000")
// if err != nil { // if err != nil {
// panic(fmt.Errorf("failed starting CSM: %v", err)) // panic(fmt.Errorf("failed starting CSM: %v", err))
// } // }
// //
// When controlling the CSM client manually, you must also inject its request
// handlers into the SDK's Session configuration for the SDK's API clients to
// publish metrics.
//
// sess, err := session.NewSession(&aws.Config{}) // sess, err := session.NewSession(&aws.Config{})
// if err != nil { // if err != nil {
// panic(fmt.Errorf("failed loading session: %v", err)) // panic(fmt.Errorf("failed loading session: %v", err))
// } // }
// //
// // Add CSM client's metric publishing request handlers to the SDK's
// // Session Configuration.
// r.InjectHandlers(&sess.Handlers) // r.InjectHandlers(&sess.Handlers)
// //
// client := s3.New(sess) // Controlling CSM client
// resp, err := client.GetObject(&s3.GetObjectInput{ //
// Bucket: aws.String("bucket"), // Once the CSM client has been enabled the Get function will return a Reporter
// Key: aws.String("key"), // value that you can use to pause and resume the metrics published to the CSM
// }) // agent. If Get function is called before the reporter is enabled with the
// Start function or via SDK's Session configuration nil will be returned.
//
// The Pause method can be called to stop the CSM client publishing metrics to
// the CSM agent. The Continue method will resume metric publishing.
//
// // Get the CSM client Reporter.
// r := csm.Get()
// //
// // Will pause monitoring // // Will pause monitoring
// r.Pause() // r.Pause()
@ -35,12 +66,4 @@
// //
// // Resume monitoring // // Resume monitoring
// r.Continue() // r.Continue()
//
// Start returns a Reporter that is used to enable or disable monitoring. If
// access to the Reporter is required later, calling Get will return the Reporter
// singleton.
//
// Example:
// r := csm.Get()
// r.Continue()
package csm package csm

View file

@ -2,6 +2,7 @@ package csm
import ( import (
"fmt" "fmt"
"strings"
"sync" "sync"
) )
@ -9,19 +10,40 @@ var (
lock sync.Mutex lock sync.Mutex
) )
// Client side metric handler names
const ( const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric" // DefaultPort is used when no port is specified.
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric" DefaultPort = "31000"
// DefaultHost is the host that will be used when none is specified.
DefaultHost = "127.0.0.1"
) )
// Start will start the a long running go routine to capture // AddressWithDefaults returns a CSM address built from the host and port
// values. If the host or port is not set, default values will be used
// instead. If host is "localhost" it will be replaced with "127.0.0.1".
func AddressWithDefaults(host, port string) string {
if len(host) == 0 || strings.EqualFold(host, "localhost") {
host = DefaultHost
}
if len(port) == 0 {
port = DefaultPort
}
// Only IP6 host can contain a colon
if strings.Contains(host, ":") {
return "[" + host + "]:" + port
}
return host + ":" + port
}
// Start will start a long running go routine to capture
// client side metrics. Calling start multiple time will only // client side metrics. Calling start multiple time will only
// start the metric listener once and will panic if a different // start the metric listener once and will panic if a different
// client ID or port is passed in. // client ID or port is passed in.
// //
// Example: // r, err := csm.Start("clientID", "127.0.0.1:31000")
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil { // if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err)) // panic(fmt.Errorf("expected no error, but received %v", err))
// } // }

View file

@ -3,6 +3,8 @@ package csm
import ( import (
"strconv" "strconv"
"time" "time"
"github.com/aws/aws-sdk-go/aws"
) )
type metricTime time.Time type metricTime time.Time
@ -39,6 +41,12 @@ type metric struct {
SDKException *string `json:"SdkException,omitempty"` SDKException *string `json:"SdkException,omitempty"`
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"` SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
FinalHTTPStatusCode *int `json:"FinalHttpStatusCode,omitempty"`
FinalAWSException *string `json:"FinalAwsException,omitempty"`
FinalAWSExceptionMessage *string `json:"FinalAwsExceptionMessage,omitempty"`
FinalSDKException *string `json:"FinalSdkException,omitempty"`
FinalSDKExceptionMessage *string `json:"FinalSdkExceptionMessage,omitempty"`
DestinationIP *string `json:"DestinationIp,omitempty"` DestinationIP *string `json:"DestinationIp,omitempty"`
ConnectionReused *int `json:"ConnectionReused,omitempty"` ConnectionReused *int `json:"ConnectionReused,omitempty"`
@ -48,4 +56,54 @@ type metric struct {
DNSLatency *int `json:"DnsLatency,omitempty"` DNSLatency *int `json:"DnsLatency,omitempty"`
TCPLatency *int `json:"TcpLatency,omitempty"` TCPLatency *int `json:"TcpLatency,omitempty"`
SSLLatency *int `json:"SslLatency,omitempty"` SSLLatency *int `json:"SslLatency,omitempty"`
MaxRetriesExceeded *int `json:"MaxRetriesExceeded,omitempty"`
}
func (m *metric) TruncateFields() {
m.ClientID = truncateString(m.ClientID, 255)
m.UserAgent = truncateString(m.UserAgent, 256)
m.AWSException = truncateString(m.AWSException, 128)
m.AWSExceptionMessage = truncateString(m.AWSExceptionMessage, 512)
m.SDKException = truncateString(m.SDKException, 128)
m.SDKExceptionMessage = truncateString(m.SDKExceptionMessage, 512)
m.FinalAWSException = truncateString(m.FinalAWSException, 128)
m.FinalAWSExceptionMessage = truncateString(m.FinalAWSExceptionMessage, 512)
m.FinalSDKException = truncateString(m.FinalSDKException, 128)
m.FinalSDKExceptionMessage = truncateString(m.FinalSDKExceptionMessage, 512)
}
func truncateString(v *string, l int) *string {
if v != nil && len(*v) > l {
nv := (*v)[:l]
return &nv
}
return v
}
func (m *metric) SetException(e metricException) {
switch te := e.(type) {
case awsException:
m.AWSException = aws.String(te.exception)
m.AWSExceptionMessage = aws.String(te.message)
case sdkException:
m.SDKException = aws.String(te.exception)
m.SDKExceptionMessage = aws.String(te.message)
}
}
func (m *metric) SetFinalException(e metricException) {
switch te := e.(type) {
case awsException:
m.FinalAWSException = aws.String(te.exception)
m.FinalAWSExceptionMessage = aws.String(te.message)
case sdkException:
m.FinalSDKException = aws.String(te.exception)
m.FinalSDKExceptionMessage = aws.String(te.message)
}
} }

View file

@ -16,25 +16,26 @@ var (
type metricChan struct { type metricChan struct {
ch chan metric ch chan metric
paused int64 paused *int64
} }
func newMetricChan(size int) metricChan { func newMetricChan(size int) metricChan {
return metricChan{ return metricChan{
ch: make(chan metric, size), ch: make(chan metric, size),
paused: new(int64),
} }
} }
func (ch *metricChan) Pause() { func (ch *metricChan) Pause() {
atomic.StoreInt64(&ch.paused, pausedEnum) atomic.StoreInt64(ch.paused, pausedEnum)
} }
func (ch *metricChan) Continue() { func (ch *metricChan) Continue() {
atomic.StoreInt64(&ch.paused, runningEnum) atomic.StoreInt64(ch.paused, runningEnum)
} }
func (ch *metricChan) IsPaused() bool { func (ch *metricChan) IsPaused() bool {
v := atomic.LoadInt64(&ch.paused) v := atomic.LoadInt64(ch.paused)
return v == pausedEnum return v == pausedEnum
} }

View file

@ -0,0 +1,26 @@
package csm
type metricException interface {
Exception() string
Message() string
}
type requestException struct {
exception string
message string
}
func (e requestException) Exception() string {
return e.exception
}
func (e requestException) Message() string {
return e.message
}
type awsException struct {
requestException
}
type sdkException struct {
requestException
}

View file

@ -10,11 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
) )
const (
// DefaultPort is used when no port is specified
DefaultPort = "31000"
)
// Reporter will gather metrics of API requests made and // Reporter will gather metrics of API requests made and
// send those metrics to the CSM endpoint. // send those metrics to the CSM endpoint.
type Reporter struct { type Reporter struct {
@ -71,7 +66,6 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
XAmzRequestID: aws.String(r.RequestID), XAmzRequestID: aws.String(r.RequestID),
AttemptCount: aws.Int(r.RetryCount + 1),
AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))), AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))),
AccessKey: aws.String(creds.AccessKeyID), AccessKey: aws.String(creds.AccessKeyID),
} }
@ -82,26 +76,29 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
if r.Error != nil { if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok { if awserr, ok := r.Error.(awserr.Error); ok {
setError(&m, awserr) m.SetException(getMetricException(awserr))
} }
} }
m.TruncateFields()
rep.metricsCh.Push(m) rep.metricsCh.Push(m)
} }
func setError(m *metric, err awserr.Error) { func getMetricException(err awserr.Error) metricException {
msg := err.Error() msg := err.Error()
code := err.Code() code := err.Code()
switch code { switch code {
case "RequestError", case request.ErrCodeRequestError,
"SerializationError", request.ErrCodeSerialization,
request.CanceledErrorCode: request.CanceledErrorCode:
m.SDKException = &code return sdkException{
m.SDKExceptionMessage = &msg requestException{exception: code, message: msg},
}
default: default:
m.AWSException = &code return awsException{
m.AWSExceptionMessage = &msg requestException{exception: code, message: msg},
}
} }
} }
@ -116,12 +113,27 @@ func (rep *Reporter) sendAPICallMetric(r *request.Request) {
API: aws.String(r.Operation.Name), API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID), Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now), Timestamp: (*metricTime)(&now),
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
Type: aws.String("ApiCall"), Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1), AttemptCount: aws.Int(r.RetryCount + 1),
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)), Region: r.Config.Region,
Latency: aws.Int(int(time.Since(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID), XAmzRequestID: aws.String(r.RequestID),
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
} }
if r.HTTPResponse != nil {
m.FinalHTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
}
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
m.SetFinalException(getMetricException(awserr))
}
}
m.TruncateFields()
// TODO: Probably want to figure something out for logging dropped // TODO: Probably want to figure something out for logging dropped
// metrics // metrics
rep.metricsCh.Push(m) rep.metricsCh.Push(m)
@ -172,8 +184,9 @@ func (rep *Reporter) start() {
} }
} }
// Pause will pause the metric channel preventing any new metrics from // Pause will pause the metric channel preventing any new metrics from being
// being added. // added. It is safe to call concurrently with other calls to Pause, but if
// called concurently with Continue can lead to unexpected state.
func (rep *Reporter) Pause() { func (rep *Reporter) Pause() {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
@ -185,8 +198,9 @@ func (rep *Reporter) Pause() {
rep.close() rep.close()
} }
// Continue will reopen the metric channel and allow for monitoring // Continue will reopen the metric channel and allow for monitoring to be
// to be resumed. // resumed. It is safe to call concurrently with other calls to Continue, but
// if called concurently with Pause can lead to unexpected state.
func (rep *Reporter) Continue() { func (rep *Reporter) Continue() {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
@ -201,10 +215,18 @@ func (rep *Reporter) Continue() {
rep.metricsCh.Continue() rep.metricsCh.Continue()
} }
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
)
// InjectHandlers will will enable client side metrics and inject the proper // InjectHandlers will will enable client side metrics and inject the proper
// handlers to handle how metrics are sent. // handlers to handle how metrics are sent.
// //
// Example: // InjectHandlers is NOT safe to call concurrently. Calling InjectHandlers
// multiple times may lead to unexpected behavior, (e.g. duplicate metrics).
//
// // Start must be called in order to inject the correct handlers // // Start must be called in order to inject the correct handlers
// r, err := csm.Start("clientID", "127.0.0.1:8094") // r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil { // if err != nil {
@ -221,11 +243,22 @@ func (rep *Reporter) InjectHandlers(handlers *request.Handlers) {
return return
} }
apiCallHandler := request.NamedHandler{Name: APICallMetricHandlerName, Fn: rep.sendAPICallMetric} handlers.Complete.PushFrontNamed(request.NamedHandler{
apiCallAttemptHandler := request.NamedHandler{Name: APICallAttemptMetricHandlerName, Fn: rep.sendAPICallAttemptMetric} Name: APICallMetricHandlerName,
Fn: rep.sendAPICallMetric,
})
handlers.Complete.PushFrontNamed(apiCallHandler) handlers.CompleteAttempt.PushFrontNamed(request.NamedHandler{
handlers.Complete.PushFrontNamed(apiCallAttemptHandler) Name: APICallAttemptMetricHandlerName,
Fn: rep.sendAPICallAttemptMetric,
handlers.AfterRetry.PushFrontNamed(apiCallAttemptHandler) })
}
// boolIntValue return 1 for true and 0 for false.
func boolIntValue(b bool) int {
if b {
return 1
}
return 0
} }

View file

@ -24,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
) )
// A Defaults provides a collection of default values for SDK clients. // A Defaults provides a collection of default values for SDK clients.
@ -112,8 +113,8 @@ func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Pro
} }
const ( const (
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI" httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
) )
// RemoteCredProvider returns a credentials provider for the default remote // RemoteCredProvider returns a credentials provider for the default remote
@ -123,8 +124,8 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
return localHTTPCredProvider(cfg, handlers, u) return localHTTPCredProvider(cfg, handlers, u)
} }
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 { if uri := os.Getenv(shareddefaults.ECSCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("http://169.254.170.2%s", uri) u := fmt.Sprintf("%s%s", shareddefaults.ECSContainerCredentialsURI, uri)
return httpCredProvider(cfg, handlers, u) return httpCredProvider(cfg, handlers, u)
} }
@ -187,6 +188,7 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
return endpointcreds.NewProviderClient(cfg, handlers, u, return endpointcreds.NewProviderClient(cfg, handlers, u,
func(p *endpointcreds.Provider) { func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute p.ExpiryWindow = 5 * time.Minute
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
}, },
) )
} }

View file

@ -4,72 +4,138 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri" "github.com/aws/aws-sdk-go/internal/sdkuri"
) )
// getToken uses the duration to return a token for EC2 metadata service,
// or an error if the request failed.
func (c *EC2Metadata) getToken(ctx aws.Context, duration time.Duration) (tokenOutput, error) {
op := &request.Operation{
Name: "GetToken",
HTTPMethod: "PUT",
HTTPPath: "/latest/api/token",
}
var output tokenOutput
req := c.NewRequest(op, nil, &output)
req.SetContext(ctx)
// remove the fetch token handler from the request handlers to avoid infinite recursion
req.Handlers.Sign.RemoveByName(fetchTokenHandlerName)
// Swap the unmarshalMetadataHandler with unmarshalTokenHandler on this request.
req.Handlers.Unmarshal.Swap(unmarshalMetadataHandlerName, unmarshalTokenHandler)
ttl := strconv.FormatInt(int64(duration/time.Second), 10)
req.HTTPRequest.Header.Set(ttlHeader, ttl)
err := req.Send()
// Errors with bad request status should be returned.
if err != nil {
err = awserr.NewRequestFailure(
awserr.New(req.HTTPResponse.Status, http.StatusText(req.HTTPResponse.StatusCode), err),
req.HTTPResponse.StatusCode, req.RequestID)
}
return output, err
}
// GetMetadata uses the path provided to request information from the EC2 // GetMetadata uses the path provided to request information from the EC2
// instance metdata service. The content will be returned as a string, or // instance metadata service. The content will be returned as a string, or
// error if the request failed. // error if the request failed.
func (c *EC2Metadata) GetMetadata(p string) (string, error) { func (c *EC2Metadata) GetMetadata(p string) (string, error) {
return c.GetMetadataWithContext(aws.BackgroundContext(), p)
}
// GetMetadataWithContext uses the path provided to request information from the EC2
// instance metadata service. The content will be returned as a string, or
// error if the request failed.
func (c *EC2Metadata) GetMetadataWithContext(ctx aws.Context, p string) (string, error) {
op := &request.Operation{ op := &request.Operation{
Name: "GetMetadata", Name: "GetMetadata",
HTTPMethod: "GET", HTTPMethod: "GET",
HTTPPath: sdkuri.PathJoin("/meta-data", p), HTTPPath: sdkuri.PathJoin("/latest/meta-data", p),
} }
output := &metadataOutput{} output := &metadataOutput{}
req := c.NewRequest(op, nil, output) req := c.NewRequest(op, nil, output)
return output.Content, req.Send() req.SetContext(ctx)
err := req.Send()
return output.Content, err
} }
// GetUserData returns the userdata that was configured for the service. If // GetUserData returns the userdata that was configured for the service. If
// there is no user-data setup for the EC2 instance a "NotFoundError" error // there is no user-data setup for the EC2 instance a "NotFoundError" error
// code will be returned. // code will be returned.
func (c *EC2Metadata) GetUserData() (string, error) { func (c *EC2Metadata) GetUserData() (string, error) {
return c.GetUserDataWithContext(aws.BackgroundContext())
}
// GetUserDataWithContext returns the userdata that was configured for the service. If
// there is no user-data setup for the EC2 instance a "NotFoundError" error
// code will be returned.
func (c *EC2Metadata) GetUserDataWithContext(ctx aws.Context) (string, error) {
op := &request.Operation{ op := &request.Operation{
Name: "GetUserData", Name: "GetUserData",
HTTPMethod: "GET", HTTPMethod: "GET",
HTTPPath: "/user-data", HTTPPath: "/latest/user-data",
} }
output := &metadataOutput{} output := &metadataOutput{}
req := c.NewRequest(op, nil, output) req := c.NewRequest(op, nil, output)
req.Handlers.UnmarshalError.PushBack(func(r *request.Request) { req.SetContext(ctx)
if r.HTTPResponse.StatusCode == http.StatusNotFound {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})
return output.Content, req.Send() err := req.Send()
return output.Content, err
} }
// GetDynamicData uses the path provided to request information from the EC2 // GetDynamicData uses the path provided to request information from the EC2
// instance metadata service for dynamic data. The content will be returned // instance metadata service for dynamic data. The content will be returned
// as a string, or error if the request failed. // as a string, or error if the request failed.
func (c *EC2Metadata) GetDynamicData(p string) (string, error) { func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
return c.GetDynamicDataWithContext(aws.BackgroundContext(), p)
}
// GetDynamicDataWithContext uses the path provided to request information from the EC2
// instance metadata service for dynamic data. The content will be returned
// as a string, or error if the request failed.
func (c *EC2Metadata) GetDynamicDataWithContext(ctx aws.Context, p string) (string, error) {
op := &request.Operation{ op := &request.Operation{
Name: "GetDynamicData", Name: "GetDynamicData",
HTTPMethod: "GET", HTTPMethod: "GET",
HTTPPath: sdkuri.PathJoin("/dynamic", p), HTTPPath: sdkuri.PathJoin("/latest/dynamic", p),
} }
output := &metadataOutput{} output := &metadataOutput{}
req := c.NewRequest(op, nil, output) req := c.NewRequest(op, nil, output)
req.SetContext(ctx)
return output.Content, req.Send() err := req.Send()
return output.Content, err
} }
// GetInstanceIdentityDocument retrieves an identity document describing an // GetInstanceIdentityDocument retrieves an identity document describing an
// instance. Error is returned if the request fails or is unable to parse // instance. Error is returned if the request fails or is unable to parse
// the response. // the response.
func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument, error) { func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument, error) {
resp, err := c.GetDynamicData("instance-identity/document") return c.GetInstanceIdentityDocumentWithContext(aws.BackgroundContext())
}
// GetInstanceIdentityDocumentWithContext retrieves an identity document describing an
// instance. Error is returned if the request fails or is unable to parse
// the response.
func (c *EC2Metadata) GetInstanceIdentityDocumentWithContext(ctx aws.Context) (EC2InstanceIdentityDocument, error) {
resp, err := c.GetDynamicDataWithContext(ctx, "instance-identity/document")
if err != nil { if err != nil {
return EC2InstanceIdentityDocument{}, return EC2InstanceIdentityDocument{},
awserr.New("EC2MetadataRequestError", awserr.New("EC2MetadataRequestError",
@ -79,7 +145,7 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
doc := EC2InstanceIdentityDocument{} doc := EC2InstanceIdentityDocument{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
return EC2InstanceIdentityDocument{}, return EC2InstanceIdentityDocument{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 instance identity document", err) "failed to decode EC2 instance identity document", err)
} }
@ -88,7 +154,12 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
// IAMInfo retrieves IAM info from the metadata API // IAMInfo retrieves IAM info from the metadata API
func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) { func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
resp, err := c.GetMetadata("iam/info") return c.IAMInfoWithContext(aws.BackgroundContext())
}
// IAMInfoWithContext retrieves IAM info from the metadata API
func (c *EC2Metadata) IAMInfoWithContext(ctx aws.Context) (EC2IAMInfo, error) {
resp, err := c.GetMetadataWithContext(ctx, "iam/info")
if err != nil { if err != nil {
return EC2IAMInfo{}, return EC2IAMInfo{},
awserr.New("EC2MetadataRequestError", awserr.New("EC2MetadataRequestError",
@ -98,7 +169,7 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
info := EC2IAMInfo{} info := EC2IAMInfo{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil { if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
return EC2IAMInfo{}, return EC2IAMInfo{},
awserr.New("SerializationError", awserr.New(request.ErrCodeSerialization,
"failed to decode EC2 IAM info", err) "failed to decode EC2 IAM info", err)
} }
@ -113,20 +184,36 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
// Region returns the region the instance is running in. // Region returns the region the instance is running in.
func (c *EC2Metadata) Region() (string, error) { func (c *EC2Metadata) Region() (string, error) {
resp, err := c.GetMetadata("placement/availability-zone") return c.RegionWithContext(aws.BackgroundContext())
}
// RegionWithContext returns the region the instance is running in.
func (c *EC2Metadata) RegionWithContext(ctx aws.Context) (string, error) {
ec2InstanceIdentityDocument, err := c.GetInstanceIdentityDocumentWithContext(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
// extract region from the ec2InstanceIdentityDocument
// returns region without the suffix. Eg: us-west-2a becomes us-west-2 region := ec2InstanceIdentityDocument.Region
return resp[:len(resp)-1], nil if len(region) == 0 {
return "", awserr.New("EC2MetadataError", "invalid region received for ec2metadata instance", nil)
}
// returns region
return region, nil
} }
// Available returns if the application has access to the EC2 Metadata service. // Available returns if the application has access to the EC2 Metadata service.
// Can be used to determine if application is running within an EC2 Instance and // Can be used to determine if application is running within an EC2 Instance and
// the metadata service is available. // the metadata service is available.
func (c *EC2Metadata) Available() bool { func (c *EC2Metadata) Available() bool {
if _, err := c.GetMetadata("instance-id"); err != nil { return c.AvailableWithContext(aws.BackgroundContext())
}
// AvailableWithContext returns if the application has access to the EC2 Metadata service.
// Can be used to determine if application is running within an EC2 Instance and
// the metadata service is available.
func (c *EC2Metadata) AvailableWithContext(ctx aws.Context) bool {
if _, err := c.GetMetadataWithContext(ctx, "instance-id"); err != nil {
return false return false
} }
@ -146,6 +233,7 @@ type EC2IAMInfo struct {
// an instance identity document // an instance identity document
type EC2InstanceIdentityDocument struct { type EC2InstanceIdentityDocument struct {
DevpayProductCodes []string `json:"devpayProductCodes"` DevpayProductCodes []string `json:"devpayProductCodes"`
MarketplaceProductCodes []string `json:"marketplaceProductCodes"`
AvailabilityZone string `json:"availabilityZone"` AvailabilityZone string `json:"availabilityZone"`
PrivateIP string `json:"privateIp"` PrivateIP string `json:"privateIp"`
Version string `json:"version"` Version string `json:"version"`

View file

@ -4,15 +4,20 @@
// This package's client can be disabled completely by setting the environment // This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to // variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot // true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environemnt variable is set to true, (case insensitive). // be used while the environment variable is set to true, (case insensitive).
//
// The endpoint of the EC2 IMDS client can be configured via the environment
// variable, AWS_EC2_METADATA_SERVICE_ENDPOINT when creating the client with a
// Session. See aws/session#Options.EC2IMDSEndpoint for more details.
package ec2metadata package ec2metadata
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"net/http" "net/http"
"net/url"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
@ -24,9 +29,25 @@ import (
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
) )
// ServiceName is the name of the service. const (
const ServiceName = "ec2metadata" // ServiceName is the name of the service.
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED" ServiceName = "ec2metadata"
disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
// Headers for Token and TTL
ttlHeader = "x-aws-ec2-metadata-token-ttl-seconds"
tokenHeader = "x-aws-ec2-metadata-token"
// Named Handler constants
fetchTokenHandlerName = "FetchTokenHandler"
unmarshalMetadataHandlerName = "unmarshalMetadataHandler"
unmarshalTokenHandlerName = "unmarshalTokenHandler"
enableTokenProviderHandlerName = "enableTokenProviderHandler"
// TTL constants
defaultTTL = 21600 * time.Second
ttlExpirationWindow = 30 * time.Second
)
// A EC2Metadata is an EC2 Metadata service Client. // A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct { type EC2Metadata struct {
@ -52,6 +73,9 @@ func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata {
// a client when not using a session. Generally using just New with a session // a client when not using a session. Generally using just New with a session
// is preferred. // is preferred.
// //
// Will remove the URL path from the endpoint provided to ensure the EC2 IMDS
// client is able to communicate with the EC2 IMDS API.
//
// If an unmodified HTTP client is provided from the stdlib default, or no client // If an unmodified HTTP client is provided from the stdlib default, or no client
// the EC2RoleProvider's EC2Metadata HTTP client's timeout will be shortened. // the EC2RoleProvider's EC2Metadata HTTP client's timeout will be shortened.
// To disable this set Config.EC2MetadataDisableTimeoutOverride to false. Enabled by default. // To disable this set Config.EC2MetadataDisableTimeoutOverride to false. Enabled by default.
@ -63,8 +87,19 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
// use a shorter timeout than default because the metadata // use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster // service is local if it is running, and to fail faster
// if not running on an ec2 instance. // if not running on an ec2 instance.
Timeout: 5 * time.Second, Timeout: 1 * time.Second,
} }
// max number of retries on the client operation
cfg.MaxRetries = aws.Int(2)
}
if u, err := url.Parse(endpoint); err == nil {
// Remove path from the endpoint since it will be added by requests.
// This is an artifact of the SDK adding `/latest` to the endpoint for
// EC2 IMDS, but this is now moved to the operation definition.
u.Path = ""
u.RawPath = ""
endpoint = u.String()
} }
svc := &EC2Metadata{ svc := &EC2Metadata{
@ -72,6 +107,7 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
cfg, cfg,
metadata.ClientInfo{ metadata.ClientInfo{
ServiceName: ServiceName, ServiceName: ServiceName,
ServiceID: ServiceName,
Endpoint: endpoint, Endpoint: endpoint,
APIVersion: "latest", APIVersion: "latest",
}, },
@ -79,18 +115,35 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
), ),
} }
svc.Handlers.Unmarshal.PushBack(unmarshalHandler) // token provider instance
tp := newTokenProvider(svc, defaultTTL)
// NamedHandler for fetching token
svc.Handlers.Sign.PushBackNamed(request.NamedHandler{
Name: fetchTokenHandlerName,
Fn: tp.fetchTokenHandler,
})
// NamedHandler for enabling token provider
svc.Handlers.Complete.PushBackNamed(request.NamedHandler{
Name: enableTokenProviderHandlerName,
Fn: tp.enableTokenProviderHandler,
})
svc.Handlers.Unmarshal.PushBackNamed(unmarshalHandler)
svc.Handlers.UnmarshalError.PushBack(unmarshalError) svc.Handlers.UnmarshalError.PushBack(unmarshalError)
svc.Handlers.Validate.Clear() svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler) svc.Handlers.Validate.PushBack(validateEndpointHandler)
// Disable the EC2 Metadata service if the environment variable is set. // Disable the EC2 Metadata service if the environment variable is set.
// This shortcirctes the service's functionality to always fail to send // This short-circuits the service's functionality to always fail to send
// requests. // requests.
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" { if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
svc.Handlers.Send.SwapNamed(request.NamedHandler{ svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name, Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) { Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
Header: http.Header{},
}
r.Error = awserr.New( r.Error = awserr.New(
request.CanceledErrorCode, request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var", "EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
@ -103,7 +156,6 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
for _, option := range opts { for _, option := range opts {
option(svc.Client) option(svc.Client)
} }
return svc return svc
} }
@ -115,30 +167,75 @@ type metadataOutput struct {
Content string Content string
} }
func unmarshalHandler(r *request.Request) { type tokenOutput struct {
Token string
TTL time.Duration
}
// unmarshal token handler is used to parse the response of a getToken operation
var unmarshalTokenHandler = request.NamedHandler{
Name: unmarshalTokenHandlerName,
Fn: func(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{} var b bytes.Buffer
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil { if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err) r.Error = awserr.NewRequestFailure(awserr.New(request.ErrCodeSerialization,
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}
v := r.HTTPResponse.Header.Get(ttlHeader)
data, ok := r.Data.(*tokenOutput)
if !ok {
return
}
data.Token = b.String()
// TTL is in seconds
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(request.ParamFormatErrCode,
"unable to parse EC2 token TTL response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}
t := time.Duration(i) * time.Second
data.TTL = t
},
}
var unmarshalHandler = request.NamedHandler{
Name: unmarshalMetadataHandlerName,
Fn: func(r *request.Request) {
defer r.HTTPResponse.Body.Close()
var b bytes.Buffer
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(request.ErrCodeSerialization,
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
return return
} }
if data, ok := r.Data.(*metadataOutput); ok { if data, ok := r.Data.(*metadataOutput); ok {
data.Content = b.String() data.Content = b.String()
} }
},
} }
func unmarshalError(r *request.Request) { func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close() defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{} var b bytes.Buffer
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err) if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error response", err),
r.HTTPResponse.StatusCode, r.RequestID)
return return
} }
// Response body format is not consistent between metadata endpoints. // Response body format is not consistent between metadata endpoints.
// Grab the error message as a string and include that as the source error // Grab the error message as a string and include that as the source error
r.Error = awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String())) r.Error = awserr.NewRequestFailure(
awserr.New("EC2MetadataError", "failed to make EC2Metadata request\n"+b.String(), nil),
r.HTTPResponse.StatusCode, r.RequestID)
} }
func validateEndpointHandler(r *request.Request) { func validateEndpointHandler(r *request.Request) {

View file

@ -0,0 +1,93 @@
package ec2metadata
import (
"net/http"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
// A tokenProvider struct provides access to EC2Metadata client
// and atomic instance of a token, along with configuredTTL for it.
// tokenProvider also provides an atomic flag to disable the
// fetch token operation.
// The disabled member will use 0 as false, and 1 as true.
type tokenProvider struct {
client *EC2Metadata
token atomic.Value
configuredTTL time.Duration
disabled uint32
}
// A ec2Token struct helps use of token in EC2 Metadata service ops
type ec2Token struct {
token string
credentials.Expiry
}
// newTokenProvider provides a pointer to a tokenProvider instance
func newTokenProvider(c *EC2Metadata, duration time.Duration) *tokenProvider {
return &tokenProvider{client: c, configuredTTL: duration}
}
// fetchTokenHandler fetches token for EC2Metadata service client by default.
func (t *tokenProvider) fetchTokenHandler(r *request.Request) {
// short-circuits to insecure data flow if tokenProvider is disabled.
if v := atomic.LoadUint32(&t.disabled); v == 1 {
return
}
if ec2Token, ok := t.token.Load().(ec2Token); ok && !ec2Token.IsExpired() {
r.HTTPRequest.Header.Set(tokenHeader, ec2Token.token)
return
}
output, err := t.client.getToken(r.Context(), t.configuredTTL)
if err != nil {
// change the disabled flag on token provider to true,
// when error is request timeout error.
if requestFailureError, ok := err.(awserr.RequestFailure); ok {
switch requestFailureError.StatusCode() {
case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed:
atomic.StoreUint32(&t.disabled, 1)
case http.StatusBadRequest:
r.Error = requestFailureError
}
// Check if request timed out while waiting for response
if e, ok := requestFailureError.OrigErr().(awserr.Error); ok {
if e.Code() == request.ErrCodeRequestError {
atomic.StoreUint32(&t.disabled, 1)
}
}
}
return
}
newToken := ec2Token{
token: output.Token,
}
newToken.SetExpiration(time.Now().Add(output.TTL), ttlExpirationWindow)
t.token.Store(newToken)
// Inject token header to the request.
if ec2Token, ok := t.token.Load().(ec2Token); ok {
r.HTTPRequest.Header.Set(tokenHeader, ec2Token.token)
}
}
// enableTokenProviderHandler enables the token provider
func (t *tokenProvider) enableTokenProviderHandler(r *request.Request) {
// If the error code status is 401, we enable the token provider
if e, ok := r.Error.(awserr.RequestFailure); ok && e != nil &&
e.StatusCode() == http.StatusUnauthorized {
t.token.Store(ec2Token{})
atomic.StoreUint32(&t.disabled, 0)
}
}

View file

@ -81,42 +81,45 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
// Customization // Customization
for i := 0; i < len(ps); i++ { for i := 0; i < len(ps); i++ {
p := &ps[i] p := &ps[i]
custAddEC2Metadata(p) custRegionalS3(p)
custAddS3DualStack(p)
custRmIotDataService(p) custRmIotDataService(p)
custFixAppAutoscalingChina(p) custFixAppAutoscalingChina(p)
custFixAppAutoscalingUsGov(p)
} }
return ps, nil return ps, nil
} }
func custAddS3DualStack(p *partition) { func custRegionalS3(p *partition) {
if p.ID != "aws" { if p.ID != "aws" {
return return
} }
s, ok := p.Services["s3"] service, ok := p.Services["s3"]
if !ok { if !ok {
return return
} }
s.Defaults.HasDualStack = boxedTrue const awsGlobal = "aws-global"
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}" const usEast1 = "us-east-1"
p.Services["s3"] = s // If global endpoint already exists no customization needed.
} if _, ok := service.Endpoints[endpointKey{Region: awsGlobal}]; ok {
return
}
func custAddEC2Metadata(p *partition) { service.PartitionEndpoint = awsGlobal
p.Services["ec2metadata"] = service{ if _, ok := service.Endpoints[endpointKey{Region: usEast1}]; !ok {
IsRegionalized: boxedFalse, service.Endpoints[endpointKey{Region: usEast1}] = endpoint{}
PartitionEndpoint: "aws-global", }
Endpoints: endpoints{ service.Endpoints[endpointKey{Region: awsGlobal}] = endpoint{
"aws-global": endpoint{ Hostname: "s3.amazonaws.com",
Hostname: "169.254.169.254/latest", CredentialScope: credentialScope{
Protocols: []string{"http"}, Region: usEast1,
},
}, },
} }
p.Services["s3"] = service
} }
func custRmIotDataService(p *partition) { func custRmIotDataService(p *partition) {
@ -135,12 +138,47 @@ func custFixAppAutoscalingChina(p *partition) {
} }
const expectHostname = `autoscaling.{region}.amazonaws.com` const expectHostname = `autoscaling.{region}.amazonaws.com`
if e, a := s.Defaults.Hostname, expectHostname; e != a { serviceDefault := s.Defaults[defaultKey{}]
if e, a := expectHostname, serviceDefault.Hostname; e != a {
fmt.Printf("custFixAppAutoscalingChina: ignoring customization, expected %s, got %s\n", e, a) fmt.Printf("custFixAppAutoscalingChina: ignoring customization, expected %s, got %s\n", e, a)
return return
} }
serviceDefault.Hostname = expectHostname + ".cn"
s.Defaults[defaultKey{}] = serviceDefault
p.Services[serviceName] = s
}
func custFixAppAutoscalingUsGov(p *partition) {
if p.ID != "aws-us-gov" {
return
}
const serviceName = "application-autoscaling"
s, ok := p.Services[serviceName]
if !ok {
return
}
serviceDefault := s.Defaults[defaultKey{}]
if a := serviceDefault.CredentialScope.Service; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a)
return
}
if a := serviceDefault.Hostname; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a)
return
}
serviceDefault.CredentialScope.Service = "application-autoscaling"
serviceDefault.Hostname = "autoscaling.{region}.amazonaws.com"
if s.Defaults == nil {
s.Defaults = make(endpointDefaults)
}
s.Defaults[defaultKey{}] = serviceDefault
s.Defaults.Hostname = expectHostname + ".cn"
p.Services[serviceName] = s p.Services[serviceName] = s
} }

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,141 @@
package endpoints
// Service identifiers
//
// Deprecated: Use client package's EndpointsID value instead of these
// ServiceIDs. These IDs are not maintained, and are out of date.
const (
A4bServiceID = "a4b" // A4b.
AcmServiceID = "acm" // Acm.
AcmPcaServiceID = "acm-pca" // AcmPca.
ApiMediatailorServiceID = "api.mediatailor" // ApiMediatailor.
ApiPricingServiceID = "api.pricing" // ApiPricing.
ApiSagemakerServiceID = "api.sagemaker" // ApiSagemaker.
ApigatewayServiceID = "apigateway" // Apigateway.
ApplicationAutoscalingServiceID = "application-autoscaling" // ApplicationAutoscaling.
Appstream2ServiceID = "appstream2" // Appstream2.
AppsyncServiceID = "appsync" // Appsync.
AthenaServiceID = "athena" // Athena.
AutoscalingServiceID = "autoscaling" // Autoscaling.
AutoscalingPlansServiceID = "autoscaling-plans" // AutoscalingPlans.
BatchServiceID = "batch" // Batch.
BudgetsServiceID = "budgets" // Budgets.
CeServiceID = "ce" // Ce.
ChimeServiceID = "chime" // Chime.
Cloud9ServiceID = "cloud9" // Cloud9.
ClouddirectoryServiceID = "clouddirectory" // Clouddirectory.
CloudformationServiceID = "cloudformation" // Cloudformation.
CloudfrontServiceID = "cloudfront" // Cloudfront.
CloudhsmServiceID = "cloudhsm" // Cloudhsm.
Cloudhsmv2ServiceID = "cloudhsmv2" // Cloudhsmv2.
CloudsearchServiceID = "cloudsearch" // Cloudsearch.
CloudtrailServiceID = "cloudtrail" // Cloudtrail.
CodebuildServiceID = "codebuild" // Codebuild.
CodecommitServiceID = "codecommit" // Codecommit.
CodedeployServiceID = "codedeploy" // Codedeploy.
CodepipelineServiceID = "codepipeline" // Codepipeline.
CodestarServiceID = "codestar" // Codestar.
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
ComprehendServiceID = "comprehend" // Comprehend.
ConfigServiceID = "config" // Config.
CurServiceID = "cur" // Cur.
DatapipelineServiceID = "datapipeline" // Datapipeline.
DaxServiceID = "dax" // Dax.
DevicefarmServiceID = "devicefarm" // Devicefarm.
DirectconnectServiceID = "directconnect" // Directconnect.
DiscoveryServiceID = "discovery" // Discovery.
DmsServiceID = "dms" // Dms.
DsServiceID = "ds" // Ds.
DynamodbServiceID = "dynamodb" // Dynamodb.
Ec2ServiceID = "ec2" // Ec2.
Ec2metadataServiceID = "ec2metadata" // Ec2metadata.
EcrServiceID = "ecr" // Ecr.
EcsServiceID = "ecs" // Ecs.
ElasticacheServiceID = "elasticache" // Elasticache.
ElasticbeanstalkServiceID = "elasticbeanstalk" // Elasticbeanstalk.
ElasticfilesystemServiceID = "elasticfilesystem" // Elasticfilesystem.
ElasticloadbalancingServiceID = "elasticloadbalancing" // Elasticloadbalancing.
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
EmailServiceID = "email" // Email.
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
EsServiceID = "es" // Es.
EventsServiceID = "events" // Events.
FirehoseServiceID = "firehose" // Firehose.
FmsServiceID = "fms" // Fms.
GameliftServiceID = "gamelift" // Gamelift.
GlacierServiceID = "glacier" // Glacier.
GlueServiceID = "glue" // Glue.
GreengrassServiceID = "greengrass" // Greengrass.
GuarddutyServiceID = "guardduty" // Guardduty.
HealthServiceID = "health" // Health.
IamServiceID = "iam" // Iam.
ImportexportServiceID = "importexport" // Importexport.
InspectorServiceID = "inspector" // Inspector.
IotServiceID = "iot" // Iot.
IotanalyticsServiceID = "iotanalytics" // Iotanalytics.
KinesisServiceID = "kinesis" // Kinesis.
KinesisanalyticsServiceID = "kinesisanalytics" // Kinesisanalytics.
KinesisvideoServiceID = "kinesisvideo" // Kinesisvideo.
KmsServiceID = "kms" // Kms.
LambdaServiceID = "lambda" // Lambda.
LightsailServiceID = "lightsail" // Lightsail.
LogsServiceID = "logs" // Logs.
MachinelearningServiceID = "machinelearning" // Machinelearning.
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
MediaconvertServiceID = "mediaconvert" // Mediaconvert.
MedialiveServiceID = "medialive" // Medialive.
MediapackageServiceID = "mediapackage" // Mediapackage.
MediastoreServiceID = "mediastore" // Mediastore.
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
MghServiceID = "mgh" // Mgh.
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
ModelsLexServiceID = "models.lex" // ModelsLex.
MonitoringServiceID = "monitoring" // Monitoring.
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
NeptuneServiceID = "neptune" // Neptune.
OpsworksServiceID = "opsworks" // Opsworks.
OpsworksCmServiceID = "opsworks-cm" // OpsworksCm.
OrganizationsServiceID = "organizations" // Organizations.
PinpointServiceID = "pinpoint" // Pinpoint.
PollyServiceID = "polly" // Polly.
RdsServiceID = "rds" // Rds.
RedshiftServiceID = "redshift" // Redshift.
RekognitionServiceID = "rekognition" // Rekognition.
ResourceGroupsServiceID = "resource-groups" // ResourceGroups.
Route53ServiceID = "route53" // Route53.
Route53domainsServiceID = "route53domains" // Route53domains.
RuntimeLexServiceID = "runtime.lex" // RuntimeLex.
RuntimeSagemakerServiceID = "runtime.sagemaker" // RuntimeSagemaker.
S3ServiceID = "s3" // S3.
S3ControlServiceID = "s3-control" // S3Control.
SagemakerServiceID = "api.sagemaker" // Sagemaker.
SdbServiceID = "sdb" // Sdb.
SecretsmanagerServiceID = "secretsmanager" // Secretsmanager.
ServerlessrepoServiceID = "serverlessrepo" // Serverlessrepo.
ServicecatalogServiceID = "servicecatalog" // Servicecatalog.
ServicediscoveryServiceID = "servicediscovery" // Servicediscovery.
ShieldServiceID = "shield" // Shield.
SmsServiceID = "sms" // Sms.
SnowballServiceID = "snowball" // Snowball.
SnsServiceID = "sns" // Sns.
SqsServiceID = "sqs" // Sqs.
SsmServiceID = "ssm" // Ssm.
StatesServiceID = "states" // States.
StoragegatewayServiceID = "storagegateway" // Storagegateway.
StreamsDynamodbServiceID = "streams.dynamodb" // StreamsDynamodb.
StsServiceID = "sts" // Sts.
SupportServiceID = "support" // Support.
SwfServiceID = "swf" // Swf.
TaggingServiceID = "tagging" // Tagging.
TransferServiceID = "transfer" // Transfer.
TranslateServiceID = "translate" // Translate.
WafServiceID = "waf" // Waf.
WafRegionalServiceID = "waf-regional" // WafRegional.
WorkdocsServiceID = "workdocs" // Workdocs.
WorkmailServiceID = "workmail" // Workmail.
WorkspacesServiceID = "workspaces" // Workspaces.
XrayServiceID = "xray" // Xray.
)

View file

@ -3,10 +3,46 @@ package endpoints
import ( import (
"fmt" "fmt"
"regexp" "regexp"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
) )
// A Logger is a minimalistic interface for the SDK to log messages to.
type Logger interface {
Log(...interface{})
}
// DualStackEndpointState is a constant to describe the dual-stack endpoint resolution
// behavior.
type DualStackEndpointState uint
const (
// DualStackEndpointStateUnset is the default value behavior for dual-stack endpoint
// resolution.
DualStackEndpointStateUnset DualStackEndpointState = iota
// DualStackEndpointStateEnabled enable dual-stack endpoint resolution for endpoints.
DualStackEndpointStateEnabled
// DualStackEndpointStateDisabled disables dual-stack endpoint resolution for endpoints.
DualStackEndpointStateDisabled
)
// FIPSEndpointState is a constant to describe the FIPS endpoint resolution behavior.
type FIPSEndpointState uint
const (
// FIPSEndpointStateUnset is the default value behavior for FIPS endpoint resolution.
FIPSEndpointStateUnset FIPSEndpointState = iota
// FIPSEndpointStateEnabled enables FIPS endpoint resolution for service endpoints.
FIPSEndpointStateEnabled
// FIPSEndpointStateDisabled disables FIPS endpoint resolution for endpoints.
FIPSEndpointStateDisabled
)
// Options provide the configuration needed to direct how the // Options provide the configuration needed to direct how the
// endpoints will be resolved. // endpoints will be resolved.
type Options struct { type Options struct {
@ -20,8 +56,19 @@ type Options struct {
// be returned. This endpoint may not be valid. If StrictMatching is // be returned. This endpoint may not be valid. If StrictMatching is
// enabled only services that are known to support dualstack will return // enabled only services that are known to support dualstack will return
// dualstack endpoints. // dualstack endpoints.
//
// Deprecated: This option will continue to function for S3 and S3 Control for backwards compatibility.
// UseDualStackEndpoint should be used to enable usage of a service's dual-stack endpoint for all service clients
// moving forward. For S3 and S3 Control, when UseDualStackEndpoint is set to a non-zero value it takes higher
// precedence then this option.
UseDualStack bool UseDualStack bool
// Sets the resolver to resolve a dual-stack endpoint for the service.
UseDualStackEndpoint DualStackEndpointState
// UseFIPSEndpoint specifies the resolver must resolve a FIPS endpoint.
UseFIPSEndpoint FIPSEndpointState
// Enables strict matching of services and regions resolved endpoints. // Enables strict matching of services and regions resolved endpoints.
// If the partition doesn't enumerate the exact service and region an // If the partition doesn't enumerate the exact service and region an
// error will be returned. This option will prevent returning endpoints // error will be returned. This option will prevent returning endpoints
@ -35,7 +82,7 @@ type Options struct {
// //
// If resolving an endpoint on the partition list the provided region will // If resolving an endpoint on the partition list the provided region will
// be used to determine which partition's domain name pattern to the service // be used to determine which partition's domain name pattern to the service
// endpoint ID with. If both the service and region are unkonwn and resolving // endpoint ID with. If both the service and region are unknown and resolving
// the endpoint on partition list an UnknownEndpointError error will be returned. // the endpoint on partition list an UnknownEndpointError error will be returned.
// //
// If resolving and endpoint on a partition specific resolver that partition's // If resolving and endpoint on a partition specific resolver that partition's
@ -46,6 +93,162 @@ type Options struct {
// //
// This option is ignored if StrictMatching is enabled. // This option is ignored if StrictMatching is enabled.
ResolveUnknownService bool ResolveUnknownService bool
// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
EC2MetadataEndpointMode EC2IMDSEndpointModeState
// STS Regional Endpoint flag helps with resolving the STS endpoint
STSRegionalEndpoint STSRegionalEndpoint
// S3 Regional Endpoint flag helps with resolving the S3 endpoint
S3UsEast1RegionalEndpoint S3UsEast1RegionalEndpoint
// ResolvedRegion is the resolved region string. If provided (non-zero length) it takes priority
// over the region name passed to the ResolveEndpoint call.
ResolvedRegion string
// Logger is the logger that will be used to log messages.
Logger Logger
// Determines whether logging of deprecated endpoints usage is enabled.
LogDeprecated bool
}
func (o Options) getEndpointVariant(service string) (v endpointVariant) {
const s3 = "s3"
const s3Control = "s3-control"
if (o.UseDualStackEndpoint == DualStackEndpointStateEnabled) ||
((service == s3 || service == s3Control) && (o.UseDualStackEndpoint == DualStackEndpointStateUnset && o.UseDualStack)) {
v |= dualStackVariant
}
if o.UseFIPSEndpoint == FIPSEndpointStateEnabled {
v |= fipsVariant
}
return v
}
// EC2IMDSEndpointModeState is an enum configuration variable describing the client endpoint mode.
type EC2IMDSEndpointModeState uint
// Enumeration values for EC2IMDSEndpointModeState
const (
EC2IMDSEndpointModeStateUnset EC2IMDSEndpointModeState = iota
EC2IMDSEndpointModeStateIPv4
EC2IMDSEndpointModeStateIPv6
)
// SetFromString sets the EC2IMDSEndpointModeState based on the provided string value. Unknown values will default to EC2IMDSEndpointModeStateUnset
func (e *EC2IMDSEndpointModeState) SetFromString(v string) error {
v = strings.TrimSpace(v)
switch {
case len(v) == 0:
*e = EC2IMDSEndpointModeStateUnset
case strings.EqualFold(v, "IPv6"):
*e = EC2IMDSEndpointModeStateIPv6
case strings.EqualFold(v, "IPv4"):
*e = EC2IMDSEndpointModeStateIPv4
default:
return fmt.Errorf("unknown EC2 IMDS endpoint mode, must be either IPv6 or IPv4")
}
return nil
}
// STSRegionalEndpoint is an enum for the states of the STS Regional Endpoint
// options.
type STSRegionalEndpoint int
func (e STSRegionalEndpoint) String() string {
switch e {
case LegacySTSEndpoint:
return "legacy"
case RegionalSTSEndpoint:
return "regional"
case UnsetSTSEndpoint:
return ""
default:
return "unknown"
}
}
const (
// UnsetSTSEndpoint represents that STS Regional Endpoint flag is not specified.
UnsetSTSEndpoint STSRegionalEndpoint = iota
// LegacySTSEndpoint represents when STS Regional Endpoint flag is specified
// to use legacy endpoints.
LegacySTSEndpoint
// RegionalSTSEndpoint represents when STS Regional Endpoint flag is specified
// to use regional endpoints.
RegionalSTSEndpoint
)
// GetSTSRegionalEndpoint function returns the STSRegionalEndpointFlag based
// on the input string provided in env config or shared config by the user.
//
// `legacy`, `regional` are the only case-insensitive valid strings for
// resolving the STS regional Endpoint flag.
func GetSTSRegionalEndpoint(s string) (STSRegionalEndpoint, error) {
switch {
case strings.EqualFold(s, "legacy"):
return LegacySTSEndpoint, nil
case strings.EqualFold(s, "regional"):
return RegionalSTSEndpoint, nil
default:
return UnsetSTSEndpoint, fmt.Errorf("unable to resolve the value of STSRegionalEndpoint for %v", s)
}
}
// S3UsEast1RegionalEndpoint is an enum for the states of the S3 us-east-1
// Regional Endpoint options.
type S3UsEast1RegionalEndpoint int
func (e S3UsEast1RegionalEndpoint) String() string {
switch e {
case LegacyS3UsEast1Endpoint:
return "legacy"
case RegionalS3UsEast1Endpoint:
return "regional"
case UnsetS3UsEast1Endpoint:
return ""
default:
return "unknown"
}
}
const (
// UnsetS3UsEast1Endpoint represents that S3 Regional Endpoint flag is not
// specified.
UnsetS3UsEast1Endpoint S3UsEast1RegionalEndpoint = iota
// LegacyS3UsEast1Endpoint represents when S3 Regional Endpoint flag is
// specified to use legacy endpoints.
LegacyS3UsEast1Endpoint
// RegionalS3UsEast1Endpoint represents when S3 Regional Endpoint flag is
// specified to use regional endpoints.
RegionalS3UsEast1Endpoint
)
// GetS3UsEast1RegionalEndpoint function returns the S3UsEast1RegionalEndpointFlag based
// on the input string provided in env config or shared config by the user.
//
// `legacy`, `regional` are the only case-insensitive valid strings for
// resolving the S3 regional Endpoint flag.
func GetS3UsEast1RegionalEndpoint(s string) (S3UsEast1RegionalEndpoint, error) {
switch {
case strings.EqualFold(s, "legacy"):
return LegacyS3UsEast1Endpoint, nil
case strings.EqualFold(s, "regional"):
return RegionalS3UsEast1Endpoint, nil
default:
return UnsetS3UsEast1Endpoint,
fmt.Errorf("unable to resolve the value of S3UsEast1RegionalEndpoint for %v", s)
}
} }
// Set combines all of the option functions together. // Set combines all of the option functions together.
@ -63,10 +266,25 @@ func DisableSSLOption(o *Options) {
// UseDualStackOption sets the UseDualStack option. Can be used as a functional // UseDualStackOption sets the UseDualStack option. Can be used as a functional
// option when resolving endpoints. // option when resolving endpoints.
//
// Deprecated: UseDualStackEndpointOption should be used to enable usage of a service's dual-stack endpoint.
// When DualStackEndpointState is set to a non-zero value it takes higher precedence then this option.
func UseDualStackOption(o *Options) { func UseDualStackOption(o *Options) {
o.UseDualStack = true o.UseDualStack = true
} }
// UseDualStackEndpointOption sets the UseDualStackEndpoint option to enabled. Can be used as a functional
// option when resolving endpoints.
func UseDualStackEndpointOption(o *Options) {
o.UseDualStackEndpoint = DualStackEndpointStateEnabled
}
// UseFIPSEndpointOption sets the UseFIPSEndpoint option to enabled. Can be used as a functional
// option when resolving endpoints.
func UseFIPSEndpointOption(o *Options) {
o.UseFIPSEndpoint = FIPSEndpointStateEnabled
}
// StrictMatchingOption sets the StrictMatching option. Can be used as a functional // StrictMatchingOption sets the StrictMatching option. Can be used as a functional
// option when resolving endpoints. // option when resolving endpoints.
func StrictMatchingOption(o *Options) { func StrictMatchingOption(o *Options) {
@ -79,6 +297,12 @@ func ResolveUnknownServiceOption(o *Options) {
o.ResolveUnknownService = true o.ResolveUnknownService = true
} }
// STSRegionalEndpointOption enables the STS endpoint resolver behavior to resolve
// STS endpoint to their regional endpoint, instead of the global endpoint.
func STSRegionalEndpointOption(o *Options) {
o.STSRegionalEndpoint = RegionalSTSEndpoint
}
// A Resolver provides the interface for functionality to resolve endpoints. // A Resolver provides the interface for functionality to resolve endpoints.
// The build in Partition and DefaultResolver return value satisfy this interface. // The build in Partition and DefaultResolver return value satisfy this interface.
type Resolver interface { type Resolver interface {
@ -138,7 +362,7 @@ func RegionsForService(ps []Partition, partitionID, serviceID string) (map[strin
if p.ID() != partitionID { if p.ID() != partitionID {
continue continue
} }
if _, ok := p.p.Services[serviceID]; !ok { if _, ok := p.p.Services[serviceID]; !(ok || serviceID == Ec2metadataServiceID) {
break break
} }
@ -170,10 +394,13 @@ func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) {
// A Partition provides the ability to enumerate the partition's regions // A Partition provides the ability to enumerate the partition's regions
// and services. // and services.
type Partition struct { type Partition struct {
id string id, dnsSuffix string
p *partition p *partition
} }
// DNSSuffix returns the base domain name of the partition.
func (p Partition) DNSSuffix() string { return p.dnsSuffix }
// ID returns the identifier of the partition. // ID returns the identifier of the partition.
func (p Partition) ID() string { return p.id } func (p Partition) ID() string { return p.id }
@ -191,7 +418,7 @@ func (p Partition) ID() string { return p.id }
// require the provided service and region to be known by the partition. // require the provided service and region to be known by the partition.
// If the endpoint cannot be strictly resolved an error will be returned. This // If the endpoint cannot be strictly resolved an error will be returned. This
// mode is useful to ensure the endpoint resolved is valid. Without // mode is useful to ensure the endpoint resolved is valid. Without
// StrictMatching enabled the endpoint returned my look valid but may not work. // StrictMatching enabled the endpoint returned may look valid but may not work.
// StrictMatching requires the SDK to be updated if you want to take advantage // StrictMatching requires the SDK to be updated if you want to take advantage
// of new regions and services expansions. // of new regions and services expansions.
// //
@ -205,7 +432,7 @@ func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (
// Regions returns a map of Regions indexed by their ID. This is useful for // Regions returns a map of Regions indexed by their ID. This is useful for
// enumerating over the regions in a partition. // enumerating over the regions in a partition.
func (p Partition) Regions() map[string]Region { func (p Partition) Regions() map[string]Region {
rs := map[string]Region{} rs := make(map[string]Region, len(p.p.Regions))
for id, r := range p.p.Regions { for id, r := range p.p.Regions {
rs[id] = Region{ rs[id] = Region{
id: id, id: id,
@ -220,7 +447,8 @@ func (p Partition) Regions() map[string]Region {
// Services returns a map of Service indexed by their ID. This is useful for // Services returns a map of Service indexed by their ID. This is useful for
// enumerating over the services in a partition. // enumerating over the services in a partition.
func (p Partition) Services() map[string]Service { func (p Partition) Services() map[string]Service {
ss := map[string]Service{} ss := make(map[string]Service, len(p.p.Services))
for id := range p.p.Services { for id := range p.p.Services {
ss[id] = Service{ ss[id] = Service{
id: id, id: id,
@ -228,6 +456,15 @@ func (p Partition) Services() map[string]Service {
} }
} }
// Since we have removed the customization that injected this into the model
// we still need to pretend that this is a modeled service.
if _, ok := ss[Ec2metadataServiceID]; !ok {
ss[Ec2metadataServiceID] = Service{
id: Ec2metadataServiceID,
p: p.p,
}
}
return ss return ss
} }
@ -255,7 +492,7 @@ func (r Region) ResolveEndpoint(service string, opts ...func(*Options)) (Resolve
func (r Region) Services() map[string]Service { func (r Region) Services() map[string]Service {
ss := map[string]Service{} ss := map[string]Service{}
for id, s := range r.p.Services { for id, s := range r.p.Services {
if _, ok := s.Endpoints[r.id]; ok { if _, ok := s.Endpoints[endpointKey{Region: r.id}]; ok {
ss[id] = Service{ ss[id] = Service{
id: id, id: id,
p: r.p, p: r.p,
@ -288,10 +525,24 @@ func (s Service) ResolveEndpoint(region string, opts ...func(*Options)) (Resolve
// an URL that can be resolved to a instance of a service. // an URL that can be resolved to a instance of a service.
func (s Service) Regions() map[string]Region { func (s Service) Regions() map[string]Region {
rs := map[string]Region{} rs := map[string]Region{}
for id := range s.p.Services[s.id].Endpoints {
if r, ok := s.p.Regions[id]; ok { service, ok := s.p.Services[s.id]
rs[id] = Region{
id: id, // Since ec2metadata customization has been removed we need to check
// if it was defined in non-standard endpoints.json file. If it's not
// then we can return the empty map as there is no regional-endpoints for IMDS.
// Otherwise, we iterate need to iterate the non-standard model.
if s.id == Ec2metadataServiceID && !ok {
return rs
}
for id := range service.Endpoints {
if id.Variant != 0 {
continue
}
if r, ok := s.p.Regions[id.Region]; ok {
rs[id.Region] = Region{
id: id.Region,
desc: r.Description, desc: r.Description,
p: s.p, p: s.p,
} }
@ -307,10 +558,13 @@ func (s Service) Regions() map[string]Region {
// A region is the AWS region the service exists in. Whereas a Endpoint is // A region is the AWS region the service exists in. Whereas a Endpoint is
// an URL that can be resolved to a instance of a service. // an URL that can be resolved to a instance of a service.
func (s Service) Endpoints() map[string]Endpoint { func (s Service) Endpoints() map[string]Endpoint {
es := map[string]Endpoint{} es := make(map[string]Endpoint, len(s.p.Services[s.id].Endpoints))
for id := range s.p.Services[s.id].Endpoints { for id := range s.p.Services[s.id].Endpoints {
es[id] = Endpoint{ if id.Variant != 0 {
id: id, continue
}
es[id.Region] = Endpoint{
id: id.Region,
serviceID: s.id, serviceID: s.id,
p: s.p, p: s.p,
} }
@ -347,6 +601,9 @@ type ResolvedEndpoint struct {
// The endpoint URL // The endpoint URL
URL string URL string
// The endpoint partition
PartitionID string
// The region that should be used for signing requests. // The region that should be used for signing requests.
SigningRegion string SigningRegion string

View file

@ -0,0 +1,24 @@
package endpoints
var legacyGlobalRegions = map[string]map[string]struct{}{
"sts": {
"ap-northeast-1": {},
"ap-south-1": {},
"ap-southeast-1": {},
"ap-southeast-2": {},
"ca-central-1": {},
"eu-central-1": {},
"eu-north-1": {},
"eu-west-1": {},
"eu-west-2": {},
"eu-west-3": {},
"sa-east-1": {},
"us-east-1": {},
"us-east-2": {},
"us-west-1": {},
"us-west-2": {},
},
"s3": {
"us-east-1": {},
},
}

View file

@ -1,20 +1,60 @@
package endpoints package endpoints
import ( import (
"encoding/json"
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
) )
const (
ec2MetadataEndpointIPv6 = "http://[fd00:ec2::254]/latest"
ec2MetadataEndpointIPv4 = "http://169.254.169.254/latest"
)
const dnsSuffixTemplateKey = "{dnsSuffix}"
// defaultKey is a compound map key of a variant and other values.
type defaultKey struct {
Variant endpointVariant
ServiceVariant serviceVariant
}
// endpointKey is a compound map key of a region and associated variant value.
type endpointKey struct {
Region string
Variant endpointVariant
}
// endpointVariant is a bit field to describe the endpoints attributes.
type endpointVariant uint64
// serviceVariant is a bit field to describe the service endpoint attributes.
type serviceVariant uint64
const (
// fipsVariant indicates that the endpoint is FIPS capable.
fipsVariant endpointVariant = 1 << (64 - 1 - iota)
// dualStackVariant indicates that the endpoint is DualStack capable.
dualStackVariant
)
var regionValidationRegex = regexp.MustCompile(`^[[:alnum:]]([[:alnum:]\-]*[[:alnum:]])?$`)
type partitions []partition type partitions []partition
func (ps partitions) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) { func (ps partitions) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
var opt Options var opt Options
opt.Set(opts...) opt.Set(opts...)
if len(opt.ResolvedRegion) > 0 {
region = opt.ResolvedRegion
}
for i := 0; i < len(ps); i++ { for i := 0; i < len(ps); i++ {
if !ps[i].canResolveEndpoint(service, region, opt.StrictMatching) { if !ps[i].canResolveEndpoint(service, region, opt) {
continue continue
} }
@ -42,56 +82,237 @@ func (ps partitions) Partitions() []Partition {
return parts return parts
} }
type endpointWithVariants struct {
endpoint
Variants []endpointWithTags `json:"variants"`
}
type endpointWithTags struct {
endpoint
Tags []string `json:"tags"`
}
type endpointDefaults map[defaultKey]endpoint
func (p *endpointDefaults) UnmarshalJSON(data []byte) error {
if *p == nil {
*p = make(endpointDefaults)
}
var e endpointWithVariants
if err := json.Unmarshal(data, &e); err != nil {
return err
}
(*p)[defaultKey{Variant: 0}] = e.endpoint
e.Hostname = ""
e.DNSSuffix = ""
for _, variant := range e.Variants {
endpointVariant, unknown := parseVariantTags(variant.Tags)
if unknown {
continue
}
var ve endpoint
ve.mergeIn(e.endpoint)
ve.mergeIn(variant.endpoint)
(*p)[defaultKey{Variant: endpointVariant}] = ve
}
return nil
}
func parseVariantTags(tags []string) (ev endpointVariant, unknown bool) {
if len(tags) == 0 {
unknown = true
return
}
for _, tag := range tags {
switch {
case strings.EqualFold("fips", tag):
ev |= fipsVariant
case strings.EqualFold("dualstack", tag):
ev |= dualStackVariant
default:
unknown = true
}
}
return ev, unknown
}
type partition struct { type partition struct {
ID string `json:"partition"` ID string `json:"partition"`
Name string `json:"partitionName"` Name string `json:"partitionName"`
DNSSuffix string `json:"dnsSuffix"` DNSSuffix string `json:"dnsSuffix"`
RegionRegex regionRegex `json:"regionRegex"` RegionRegex regionRegex `json:"regionRegex"`
Defaults endpoint `json:"defaults"` Defaults endpointDefaults `json:"defaults"`
Regions regions `json:"regions"` Regions regions `json:"regions"`
Services services `json:"services"` Services services `json:"services"`
} }
func (p partition) Partition() Partition { func (p partition) Partition() Partition {
return Partition{ return Partition{
dnsSuffix: p.DNSSuffix,
id: p.ID, id: p.ID,
p: &p, p: &p,
} }
} }
func (p partition) canResolveEndpoint(service, region string, strictMatch bool) bool { func (p partition) canResolveEndpoint(service, region string, options Options) bool {
s, hasService := p.Services[service] s, hasService := p.Services[service]
_, hasEndpoint := s.Endpoints[region] _, hasEndpoint := s.Endpoints[endpointKey{
Region: region,
Variant: options.getEndpointVariant(service),
}]
if hasEndpoint && hasService { if hasEndpoint && hasService {
return true return true
} }
if strictMatch { if options.StrictMatching {
return false return false
} }
return p.RegionRegex.MatchString(region) return p.RegionRegex.MatchString(region)
} }
func allowLegacyEmptyRegion(service string) bool {
legacy := map[string]struct{}{
"budgets": {},
"ce": {},
"chime": {},
"cloudfront": {},
"ec2metadata": {},
"iam": {},
"importexport": {},
"organizations": {},
"route53": {},
"sts": {},
"support": {},
"waf": {},
}
_, allowed := legacy[service]
return allowed
}
func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) { func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) {
var opt Options var opt Options
opt.Set(opts...) opt.Set(opts...)
if len(opt.ResolvedRegion) > 0 {
region = opt.ResolvedRegion
}
s, hasService := p.Services[service] s, hasService := p.Services[service]
if !(hasService || opt.ResolveUnknownService) {
if service == Ec2metadataServiceID && !hasService {
endpoint := getEC2MetadataEndpoint(p.ID, service, opt.EC2MetadataEndpointMode)
return endpoint, nil
}
if len(service) == 0 || !(hasService || opt.ResolveUnknownService) {
// Only return error if the resolver will not fallback to creating // Only return error if the resolver will not fallback to creating
// endpoint based on service endpoint ID passed in. // endpoint based on service endpoint ID passed in.
return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services)) return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services))
} }
e, hasEndpoint := s.endpointForRegion(region) if len(region) == 0 && allowLegacyEmptyRegion(service) && len(s.PartitionEndpoint) != 0 {
if !hasEndpoint && opt.StrictMatching { region = s.PartitionEndpoint
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
} }
defs := []endpoint{p.Defaults, s.Defaults} if r, ok := isLegacyGlobalRegion(service, region, opt); ok {
return e.resolve(service, region, p.DNSSuffix, defs, opt), nil region = r
}
variant := opt.getEndpointVariant(service)
endpoints := s.Endpoints
serviceDefaults, hasServiceDefault := s.Defaults[defaultKey{Variant: variant}]
// If we searched for a variant which may have no explicit service defaults,
// then we need to inherit the standard service defaults except the hostname and dnsSuffix
if variant != 0 && !hasServiceDefault {
serviceDefaults = s.Defaults[defaultKey{}]
serviceDefaults.Hostname = ""
serviceDefaults.DNSSuffix = ""
}
partitionDefaults, hasPartitionDefault := p.Defaults[defaultKey{Variant: variant}]
var dnsSuffix string
if len(serviceDefaults.DNSSuffix) > 0 {
dnsSuffix = serviceDefaults.DNSSuffix
} else if variant == 0 {
// For legacy reasons the partition dnsSuffix is not in the defaults, so if we looked for
// a non-variant endpoint then we need to set the dnsSuffix.
dnsSuffix = p.DNSSuffix
}
noDefaults := !hasServiceDefault && !hasPartitionDefault
e, hasEndpoint := s.endpointForRegion(region, endpoints, variant)
if len(region) == 0 || (!hasEndpoint && (opt.StrictMatching || noDefaults)) {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(endpoints, variant))
}
defs := []endpoint{partitionDefaults, serviceDefaults}
return e.resolve(service, p.ID, region, dnsSuffixTemplateKey, dnsSuffix, defs, opt)
}
func getEC2MetadataEndpoint(partitionID, service string, mode EC2IMDSEndpointModeState) ResolvedEndpoint {
switch mode {
case EC2IMDSEndpointModeStateIPv6:
return ResolvedEndpoint{
URL: ec2MetadataEndpointIPv6,
PartitionID: partitionID,
SigningRegion: "aws-global",
SigningName: service,
SigningNameDerived: true,
SigningMethod: "v4",
}
case EC2IMDSEndpointModeStateIPv4:
fallthrough
default:
return ResolvedEndpoint{
URL: ec2MetadataEndpointIPv4,
PartitionID: partitionID,
SigningRegion: "aws-global",
SigningName: service,
SigningNameDerived: true,
SigningMethod: "v4",
}
}
}
func isLegacyGlobalRegion(service string, region string, opt Options) (string, bool) {
if opt.getEndpointVariant(service) != 0 {
return "", false
}
const (
sts = "sts"
s3 = "s3"
awsGlobal = "aws-global"
)
switch {
case service == sts && opt.STSRegionalEndpoint == RegionalSTSEndpoint:
return region, false
case service == s3 && opt.S3UsEast1RegionalEndpoint == RegionalS3UsEast1Endpoint:
return region, false
default:
if _, ok := legacyGlobalRegions[service][region]; ok {
return awsGlobal, true
}
}
return region, false
} }
func serviceList(ss services) []string { func serviceList(ss services) []string {
@ -101,10 +322,13 @@ func serviceList(ss services) []string {
} }
return list return list
} }
func endpointList(es endpoints) []string { func endpointList(es serviceEndpoints, variant endpointVariant) []string {
list := make([]string, 0, len(es)) list := make([]string, 0, len(es))
for k := range es { for k := range es {
list = append(list, k) if k.Variant != variant {
continue
}
list = append(list, k.Region)
} }
return list return list
} }
@ -138,17 +362,17 @@ type services map[string]service
type service struct { type service struct {
PartitionEndpoint string `json:"partitionEndpoint"` PartitionEndpoint string `json:"partitionEndpoint"`
IsRegionalized boxedBool `json:"isRegionalized,omitempty"` IsRegionalized boxedBool `json:"isRegionalized,omitempty"`
Defaults endpoint `json:"defaults"` Defaults endpointDefaults `json:"defaults"`
Endpoints endpoints `json:"endpoints"` Endpoints serviceEndpoints `json:"endpoints"`
} }
func (s *service) endpointForRegion(region string) (endpoint, bool) { func (s *service) endpointForRegion(region string, endpoints serviceEndpoints, variant endpointVariant) (endpoint, bool) {
if s.IsRegionalized == boxedFalse { if e, ok := endpoints[endpointKey{Region: region, Variant: variant}]; ok {
return s.Endpoints[s.PartitionEndpoint], region == s.PartitionEndpoint return e, true
} }
if e, ok := s.Endpoints[region]; ok { if s.IsRegionalized == boxedFalse {
return e, true return endpoints[endpointKey{Region: s.PartitionEndpoint, Variant: variant}], region == s.PartitionEndpoint
} }
// Unable to find any matching endpoint, return // Unable to find any matching endpoint, return
@ -156,22 +380,73 @@ func (s *service) endpointForRegion(region string) (endpoint, bool) {
return endpoint{}, false return endpoint{}, false
} }
type endpoints map[string]endpoint type serviceEndpoints map[endpointKey]endpoint
func (s *serviceEndpoints) UnmarshalJSON(data []byte) error {
if *s == nil {
*s = make(serviceEndpoints)
}
var regionToEndpoint map[string]endpointWithVariants
if err := json.Unmarshal(data, &regionToEndpoint); err != nil {
return err
}
for region, e := range regionToEndpoint {
(*s)[endpointKey{Region: region}] = e.endpoint
e.Hostname = ""
e.DNSSuffix = ""
for _, variant := range e.Variants {
endpointVariant, unknown := parseVariantTags(variant.Tags)
if unknown {
continue
}
var ve endpoint
ve.mergeIn(e.endpoint)
ve.mergeIn(variant.endpoint)
(*s)[endpointKey{Region: region, Variant: endpointVariant}] = ve
}
}
return nil
}
type endpoint struct { type endpoint struct {
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
Protocols []string `json:"protocols"` Protocols []string `json:"protocols"`
CredentialScope credentialScope `json:"credentialScope"` CredentialScope credentialScope `json:"credentialScope"`
// Custom fields not modeled DNSSuffix string `json:"dnsSuffix"`
HasDualStack boxedBool `json:"-"`
DualStackHostname string `json:"-"`
// Signature Version not used // Signature Version not used
SignatureVersions []string `json:"signatureVersions"` SignatureVersions []string `json:"signatureVersions"`
// SSLCommonName not used. // SSLCommonName not used.
SSLCommonName string `json:"sslCommonName"` SSLCommonName string `json:"sslCommonName"`
Deprecated boxedBool `json:"deprecated"`
}
// isZero returns whether the endpoint structure is an empty (zero) value.
func (e endpoint) isZero() bool {
switch {
case len(e.Hostname) != 0:
return false
case len(e.Protocols) != 0:
return false
case e.CredentialScope != (credentialScope{}):
return false
case len(e.SignatureVersions) != 0:
return false
case len(e.SSLCommonName) != 0:
return false
}
return true
} }
const ( const (
@ -200,7 +475,7 @@ func getByPriority(s []string, p []string, def string) string {
return s[0] return s[0]
} }
func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint { func (e endpoint) resolve(service, partitionID, region, dnsSuffixTemplateVariable, dnsSuffix string, defs []endpoint, opts Options) (ResolvedEndpoint, error) {
var merged endpoint var merged endpoint
for _, def := range defs { for _, def := range defs {
merged.mergeIn(def) merged.mergeIn(def)
@ -208,20 +483,6 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
merged.mergeIn(e) merged.mergeIn(e)
e = merged e = merged
hostname := e.Hostname
// Offset the hostname for dualstack if enabled
if opts.UseDualStack && e.HasDualStack == boxedTrue {
hostname = e.DualStackHostname
}
u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u)
signingRegion := e.CredentialScope.Region signingRegion := e.CredentialScope.Region
if len(signingRegion) == 0 { if len(signingRegion) == 0 {
signingRegion = region signingRegion = region
@ -234,13 +495,35 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
signingNameDerived = true signingNameDerived = true
} }
hostname := e.Hostname
if !validateInputRegion(region) {
return ResolvedEndpoint{}, fmt.Errorf("invalid region identifier format provided")
}
if len(merged.DNSSuffix) > 0 {
dnsSuffix = merged.DNSSuffix
}
u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, dnsSuffixTemplateVariable, dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u)
if e.Deprecated == boxedTrue && opts.LogDeprecated && opts.Logger != nil {
opts.Logger.Log(fmt.Sprintf("endpoint identifier %q, url %q marked as deprecated", region, u))
}
return ResolvedEndpoint{ return ResolvedEndpoint{
URL: u, URL: u,
PartitionID: partitionID,
SigningRegion: signingRegion, SigningRegion: signingRegion,
SigningName: signingName, SigningName: signingName,
SigningNameDerived: signingNameDerived, SigningNameDerived: signingNameDerived,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner), SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
} }, nil
} }
func getEndpointScheme(protocols []string, disableSSL bool) string { func getEndpointScheme(protocols []string, disableSSL bool) string {
@ -270,11 +553,11 @@ func (e *endpoint) mergeIn(other endpoint) {
if len(other.SSLCommonName) > 0 { if len(other.SSLCommonName) > 0 {
e.SSLCommonName = other.SSLCommonName e.SSLCommonName = other.SSLCommonName
} }
if other.HasDualStack != boxedBoolUnset { if len(other.DNSSuffix) > 0 {
e.HasDualStack = other.HasDualStack e.DNSSuffix = other.DNSSuffix
} }
if len(other.DualStackHostname) > 0 { if other.Deprecated != boxedBoolUnset {
e.DualStackHostname = other.DualStackHostname e.Deprecated = other.Deprecated
} }
} }
@ -305,3 +588,7 @@ const (
boxedFalse boxedFalse
boxedTrue boxedTrue
) )
func validateInputRegion(region string) bool {
return regionValidationRegex.MatchString(region)
}

View file

@ -1,3 +1,4 @@
//go:build codegen
// +build codegen // +build codegen
package endpoints package endpoints
@ -16,6 +17,10 @@ import (
type CodeGenOptions struct { type CodeGenOptions struct {
// Options for how the model will be decoded. // Options for how the model will be decoded.
DecodeModelOptions DecodeModelOptions DecodeModelOptions DecodeModelOptions
// Disables code generation of the service endpoint prefix IDs defined in
// the model.
DisableGenerateServiceIDs bool
} }
// Set combines all of the option functions together // Set combines all of the option functions together
@ -39,8 +44,16 @@ func CodeGenModel(modelFile io.Reader, outFile io.Writer, optFns ...func(*CodeGe
return err return err
} }
v := struct {
Resolver
CodeGenOptions
}{
Resolver: resolver,
CodeGenOptions: opts,
}
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl)) tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
if err := tmpl.ExecuteTemplate(outFile, "defaults", resolver); err != nil { if err := tmpl.ExecuteTemplate(outFile, "defaults", v); err != nil {
return fmt.Errorf("failed to execute template, %v", err) return fmt.Errorf("failed to execute template, %v", err)
} }
@ -142,6 +155,56 @@ func serviceSet(ps partitions) map[string]struct{} {
return set return set
} }
func endpointVariantSetter(variant endpointVariant) (string, error) {
if variant == 0 {
return "0", nil
}
if variant > (fipsVariant | dualStackVariant) {
return "", fmt.Errorf("unknown endpoint variant")
}
var symbols []string
if variant&fipsVariant != 0 {
symbols = append(symbols, "fipsVariant")
}
if variant&dualStackVariant != 0 {
symbols = append(symbols, "dualStackVariant")
}
v := strings.Join(symbols, "|")
return v, nil
}
func endpointKeySetter(e endpointKey) (string, error) {
var sb strings.Builder
sb.WriteString("endpointKey{\n")
sb.WriteString(fmt.Sprintf("Region: %q,\n", e.Region))
if e.Variant != 0 {
variantSetter, err := endpointVariantSetter(e.Variant)
if err != nil {
return "", err
}
sb.WriteString(fmt.Sprintf("Variant: %s,\n", variantSetter))
}
sb.WriteString("}")
return sb.String(), nil
}
func defaultKeySetter(e defaultKey) (string, error) {
var sb strings.Builder
sb.WriteString("defaultKey{\n")
if e.Variant != 0 {
variantSetter, err := endpointVariantSetter(e.Variant)
if err != nil {
return "", err
}
sb.WriteString(fmt.Sprintf("Variant: %s,\n", variantSetter))
}
sb.WriteString("}")
return sb.String(), nil
}
var funcMap = template.FuncMap{ var funcMap = template.FuncMap{
"ToSymbol": toSymbol, "ToSymbol": toSymbol,
"QuoteString": quoteString, "QuoteString": quoteString,
@ -154,6 +217,9 @@ var funcMap = template.FuncMap{
"StringSliceIfSet": stringSliceIfSet, "StringSliceIfSet": stringSliceIfSet,
"EndpointIsSet": endpointIsSet, "EndpointIsSet": endpointIsSet,
"ServicesSet": serviceSet, "ServicesSet": serviceSet,
"EndpointVariantSetter": endpointVariantSetter,
"EndpointKeySetter": endpointKeySetter,
"DefaultKeySetter": defaultKeySetter,
} }
const v3Tmpl = ` const v3Tmpl = `
@ -166,15 +232,17 @@ import (
"regexp" "regexp"
) )
{{ template "partition consts" . }} {{ template "partition consts" $.Resolver }}
{{ range $_, $partition := . }} {{ range $_, $partition := $.Resolver }}
{{ template "partition region consts" $partition }} {{ template "partition region consts" $partition }}
{{ end }} {{ end }}
{{ template "service consts" . }} {{ if not $.DisableGenerateServiceIDs -}}
{{ template "service consts" $.Resolver }}
{{- end }}
{{ template "endpoint resolvers" . }} {{ template "endpoint resolvers" $.Resolver }}
{{- end }} {{- end }}
{{ define "partition consts" }} {{ define "partition consts" }}
@ -257,9 +325,9 @@ partition{
{{ StringIfSet "Name: %q,\n" .Name -}} {{ StringIfSet "Name: %q,\n" .Name -}}
{{ StringIfSet "DNSSuffix: %q,\n" .DNSSuffix -}} {{ StringIfSet "DNSSuffix: %q,\n" .DNSSuffix -}}
RegionRegex: {{ template "gocode RegionRegex" .RegionRegex }}, RegionRegex: {{ template "gocode RegionRegex" .RegionRegex }},
{{ if EndpointIsSet .Defaults -}} {{ if (gt (len .Defaults) 0) -}}
Defaults: {{ template "gocode Endpoint" .Defaults }}, Defaults: {{ template "gocode Defaults" .Defaults -}},
{{- end }} {{ end -}}
Regions: {{ template "gocode Regions" .Regions }}, Regions: {{ template "gocode Regions" .Regions }},
Services: {{ template "gocode Services" .Services }}, Services: {{ template "gocode Services" .Services }},
} }
@ -300,19 +368,27 @@ services{
service{ service{
{{ StringIfSet "PartitionEndpoint: %q,\n" .PartitionEndpoint -}} {{ StringIfSet "PartitionEndpoint: %q,\n" .PartitionEndpoint -}}
{{ BoxedBoolIfSet "IsRegionalized: %s,\n" .IsRegionalized -}} {{ BoxedBoolIfSet "IsRegionalized: %s,\n" .IsRegionalized -}}
{{ if EndpointIsSet .Defaults -}} {{ if (gt (len .Defaults) 0) -}}
Defaults: {{ template "gocode Endpoint" .Defaults -}}, Defaults: {{ template "gocode Defaults" .Defaults -}},
{{- end }} {{ end -}}
{{ if .Endpoints -}} {{ if .Endpoints -}}
Endpoints: {{ template "gocode Endpoints" .Endpoints }}, Endpoints: {{ template "gocode Endpoints" .Endpoints }},
{{- end }} {{- end }}
} }
{{- end }} {{- end }}
{{ define "gocode Endpoints" -}} {{ define "gocode Defaults" -}}
endpoints{ endpointDefaults{
{{ range $id, $endpoint := . -}} {{ range $id, $endpoint := . -}}
"{{ $id }}": {{ template "gocode Endpoint" $endpoint }}, {{ DefaultKeySetter $id }}: {{ template "gocode Endpoint" $endpoint }},
{{ end }}
}
{{- end }}
{{ define "gocode Endpoints" -}}
serviceEndpoints{
{{ range $id, $endpoint := . -}}
{{ EndpointKeySetter $id }}: {{ template "gocode Endpoint" $endpoint }},
{{ end }} {{ end }}
} }
{{- end }} {{- end }}
@ -320,6 +396,7 @@ endpoints{
{{ define "gocode Endpoint" -}} {{ define "gocode Endpoint" -}}
endpoint{ endpoint{
{{ StringIfSet "Hostname: %q,\n" .Hostname -}} {{ StringIfSet "Hostname: %q,\n" .Hostname -}}
{{ StringIfSet "DNSSuffix: %q,\n" .DNSSuffix -}}
{{ StringIfSet "SSLCommonName: %q,\n" .SSLCommonName -}} {{ StringIfSet "SSLCommonName: %q,\n" .SSLCommonName -}}
{{ StringSliceIfSet "Protocols: []string{%s},\n" .Protocols -}} {{ StringSliceIfSet "Protocols: []string{%s},\n" .Protocols -}}
{{ StringSliceIfSet "SignatureVersions: []string{%s},\n" .SignatureVersions -}} {{ StringSliceIfSet "SignatureVersions: []string{%s},\n" .SignatureVersions -}}
@ -329,9 +406,7 @@ endpoint{
{{ StringIfSet "Service: %q,\n" .CredentialScope.Service -}} {{ StringIfSet "Service: %q,\n" .CredentialScope.Service -}}
}, },
{{- end }} {{- end }}
{{ BoxedBoolIfSet "HasDualStack: %s,\n" .HasDualStack -}} {{ BoxedBoolIfSet "Deprecated: %s,\n" .Deprecated -}}
{{ StringIfSet "DualStackHostname: %q,\n" .DualStackHostname -}}
} }
{{- end }} {{- end }}
` `

View file

@ -5,13 +5,9 @@ import "github.com/aws/aws-sdk-go/aws/awserr"
var ( var (
// ErrMissingRegion is an error that is returned if region configuration is // ErrMissingRegion is an error that is returned if region configuration is
// not found. // not found.
//
// @readonly
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil) ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be // ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service. // resolved for a service.
//
// @readonly
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil) ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
) )

View file

@ -77,6 +77,9 @@ const (
// wire unmarshaled message content of requests and responses made while // wire unmarshaled message content of requests and responses made while
// using the SDK Will also enable LogDebug. // using the SDK Will also enable LogDebug.
LogDebugWithEventStreamBody LogDebugWithEventStreamBody
// LogDebugWithDeprecated states the SDK should log details about deprecated functionality.
LogDebugWithDeprecated
) )
// A Logger is a minimalistic interface for the SDK to log messages to. Should // A Logger is a minimalistic interface for the SDK to log messages to. Should

View file

@ -1,18 +1,18 @@
// +build !appengine,!plan9
package request package request
import ( import (
"net" "strings"
"os"
"syscall"
) )
func isErrConnectionReset(err error) bool { func isErrConnectionReset(err error) bool {
if opErr, ok := err.(*net.OpError); ok { if strings.Contains(err.Error(), "read: connection reset") {
if sysErr, ok := opErr.Err.(*os.SyscallError); ok { return false
return sysErr.Err == syscall.ECONNRESET
} }
if strings.Contains(err.Error(), "use of closed network connection") ||
strings.Contains(err.Error(), "connection reset") ||
strings.Contains(err.Error(), "broken pipe") {
return true
} }
return false return false

View file

@ -1,11 +0,0 @@
// +build appengine plan9
package request
import (
"strings"
)
func isErrConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset")
}

View file

@ -10,6 +10,7 @@ import (
type Handlers struct { type Handlers struct {
Validate HandlerList Validate HandlerList
Build HandlerList Build HandlerList
BuildStream HandlerList
Sign HandlerList Sign HandlerList
Send HandlerList Send HandlerList
ValidateResponse HandlerList ValidateResponse HandlerList
@ -19,14 +20,16 @@ type Handlers struct {
UnmarshalError HandlerList UnmarshalError HandlerList
Retry HandlerList Retry HandlerList
AfterRetry HandlerList AfterRetry HandlerList
CompleteAttempt HandlerList
Complete HandlerList Complete HandlerList
} }
// Copy returns of this handler's lists. // Copy returns a copy of this handler's lists.
func (h *Handlers) Copy() Handlers { func (h *Handlers) Copy() Handlers {
return Handlers{ return Handlers{
Validate: h.Validate.copy(), Validate: h.Validate.copy(),
Build: h.Build.copy(), Build: h.Build.copy(),
BuildStream: h.BuildStream.copy(),
Sign: h.Sign.copy(), Sign: h.Sign.copy(),
Send: h.Send.copy(), Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(), ValidateResponse: h.ValidateResponse.copy(),
@ -36,14 +39,16 @@ func (h *Handlers) Copy() Handlers {
UnmarshalMeta: h.UnmarshalMeta.copy(), UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(), Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(), AfterRetry: h.AfterRetry.copy(),
CompleteAttempt: h.CompleteAttempt.copy(),
Complete: h.Complete.copy(), Complete: h.Complete.copy(),
} }
} }
// Clear removes callback functions for all handlers // Clear removes callback functions for all handlers.
func (h *Handlers) Clear() { func (h *Handlers) Clear() {
h.Validate.Clear() h.Validate.Clear()
h.Build.Clear() h.Build.Clear()
h.BuildStream.Clear()
h.Send.Clear() h.Send.Clear()
h.Sign.Clear() h.Sign.Clear()
h.Unmarshal.Clear() h.Unmarshal.Clear()
@ -53,9 +58,58 @@ func (h *Handlers) Clear() {
h.ValidateResponse.Clear() h.ValidateResponse.Clear()
h.Retry.Clear() h.Retry.Clear()
h.AfterRetry.Clear() h.AfterRetry.Clear()
h.CompleteAttempt.Clear()
h.Complete.Clear() h.Complete.Clear()
} }
// IsEmpty returns if there are no handlers in any of the handlerlists.
func (h *Handlers) IsEmpty() bool {
if h.Validate.Len() != 0 {
return false
}
if h.Build.Len() != 0 {
return false
}
if h.BuildStream.Len() != 0 {
return false
}
if h.Send.Len() != 0 {
return false
}
if h.Sign.Len() != 0 {
return false
}
if h.Unmarshal.Len() != 0 {
return false
}
if h.UnmarshalStream.Len() != 0 {
return false
}
if h.UnmarshalMeta.Len() != 0 {
return false
}
if h.UnmarshalError.Len() != 0 {
return false
}
if h.ValidateResponse.Len() != 0 {
return false
}
if h.Retry.Len() != 0 {
return false
}
if h.AfterRetry.Len() != 0 {
return false
}
if h.CompleteAttempt.Len() != 0 {
return false
}
if h.Complete.Len() != 0 {
return false
}
return true
}
// A HandlerListRunItem represents an entry in the HandlerList which // A HandlerListRunItem represents an entry in the HandlerList which
// is being run. // is being run.
type HandlerListRunItem struct { type HandlerListRunItem struct {
@ -272,3 +326,18 @@ func MakeAddToUserAgentFreeFormHandler(s string) func(*Request) {
AddToUserAgent(r, s) AddToUserAgent(r, s)
} }
} }
// WithSetRequestHeaders updates the operation request's HTTP header to contain
// the header key value pairs provided. If the header key already exists in the
// request's HTTP header set, the existing value(s) will be replaced.
func WithSetRequestHeaders(h map[string]string) Option {
return withRequestHeader(h).SetRequestHeaders
}
type withRequestHeader map[string]string
func (h withRequestHeader) SetRequestHeaders(r *Request) {
for k, v := range h {
r.HTTPRequest.Header[k] = []string{v}
}
}

View file

@ -15,12 +15,15 @@ type offsetReader struct {
closed bool closed bool
} }
func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader { func newOffsetReader(buf io.ReadSeeker, offset int64) (*offsetReader, error) {
reader := &offsetReader{} reader := &offsetReader{}
buf.Seek(offset, sdkio.SeekStart) _, err := buf.Seek(offset, sdkio.SeekStart)
if err != nil {
return nil, err
}
reader.buf = buf reader.buf = buf
return reader return reader, nil
} }
// Close will close the instance of the offset reader's access to // Close will close the instance of the offset reader's access to
@ -54,7 +57,9 @@ func (o *offsetReader) Seek(offset int64, whence int) (int64, error) {
// CloseAndCopy will return a new offsetReader with a copy of the old buffer // CloseAndCopy will return a new offsetReader with a copy of the old buffer
// and close the old buffer. // and close the old buffer.
func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader { func (o *offsetReader) CloseAndCopy(offset int64) (*offsetReader, error) {
o.Close() if err := o.Close(); err != nil {
return nil, err
}
return newOffsetReader(o.buf, offset) return newOffsetReader(o.buf, offset)
} }

View file

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
@ -37,6 +37,10 @@ const (
// API request that was canceled. Requests given a aws.Context may // API request that was canceled. Requests given a aws.Context may
// return this error when canceled. // return this error when canceled.
CanceledErrorCode = "RequestCanceled" CanceledErrorCode = "RequestCanceled"
// ErrCodeRequestError is an error preventing the SDK from continuing to
// process the request.
ErrCodeRequestError = "RequestError"
) )
// A Request is the service request to be made. // A Request is the service request to be made.
@ -52,6 +56,7 @@ type Request struct {
HTTPRequest *http.Request HTTPRequest *http.Request
HTTPResponse *http.Response HTTPResponse *http.Response
Body io.ReadSeeker Body io.ReadSeeker
streamingBody io.ReadCloser
BodyStart int64 // offset from beginning of Body that the request body starts BodyStart int64 // offset from beginning of Body that the request body starts
Params interface{} Params interface{}
Error error Error error
@ -65,6 +70,15 @@ type Request struct {
LastSignedAt time.Time LastSignedAt time.Time
DisableFollowRedirects bool DisableFollowRedirects bool
// Additional API error codes that should be retried. IsErrorRetryable
// will consider these codes in addition to its built in cases.
RetryErrorCodes []string
// Additional API error codes that should be retried with throttle backoff
// delay. IsErrorThrottle will consider these codes in addition to its
// built in cases.
ThrottleErrorCodes []string
// A value greater than 0 instructs the request to be signed as Presigned URL // A value greater than 0 instructs the request to be signed as Presigned URL
// You should not set this field directly. Instead use Request's // You should not set this field directly. Instead use Request's
// Presign or PresignRequest methods. // Presign or PresignRequest methods.
@ -91,8 +105,12 @@ type Operation struct {
BeforePresignFn func(r *Request) error BeforePresignFn func(r *Request) error
} }
// New returns a new Request pointer for the service API // New returns a new Request pointer for the service API operation and
// operation and parameters. // parameters.
//
// A Retryer should be provided to direct how the request is retried. If
// Retryer is nil, a default no retry value will be used. You can use
// NoOpRetryer in the Client package to disable retry behavior directly.
// //
// Params is any value of input parameters to be the request payload. // Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response // Data is pointer value to an object which the request's response
@ -100,6 +118,10 @@ type Operation struct {
func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers, func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request { retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request {
if retryer == nil {
retryer = noOpRetryer{}
}
method := operation.HTTPMethod method := operation.HTTPMethod
if method == "" { if method == "" {
method = "POST" method = "POST"
@ -108,13 +130,26 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
httpReq, _ := http.NewRequest(method, "", nil) httpReq, _ := http.NewRequest(method, "", nil)
var err error var err error
httpReq.URL, err = url.Parse(clientInfo.Endpoint + operation.HTTPPath) httpReq.URL, err = url.Parse(clientInfo.Endpoint)
if err != nil { if err != nil {
httpReq.URL = &url.URL{} httpReq.URL = &url.URL{}
err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err) err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err)
} }
SanitizeHostForHeader(httpReq) if len(operation.HTTPPath) != 0 {
opHTTPPath := operation.HTTPPath
var opQueryString string
if idx := strings.Index(opHTTPPath, "?"); idx >= 0 {
opQueryString = opHTTPPath[idx+1:]
opHTTPPath = opHTTPPath[:idx]
}
if strings.HasSuffix(httpReq.URL.Path, "/") && strings.HasPrefix(opHTTPPath, "/") {
opHTTPPath = opHTTPPath[1:]
}
httpReq.URL.Path += opHTTPPath
httpReq.URL.RawQuery = opQueryString
}
r := &Request{ r := &Request{
Config: cfg, Config: cfg,
@ -122,7 +157,6 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
Handlers: handlers.Copy(), Handlers: handlers.Copy(),
Retryer: retryer, Retryer: retryer,
AttemptTime: time.Now(),
Time: time.Now(), Time: time.Now(),
ExpireTime: 0, ExpireTime: 0,
Operation: operation, Operation: operation,
@ -233,6 +267,10 @@ func (r *Request) WillRetry() bool {
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries() return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
} }
func fmtAttemptCount(retryCount, maxRetries int) string {
return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries)
}
// ParamsFilled returns if the request's parameters have been populated // ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are // and the parameters are valid. False is returned if no parameters are
// provided or invalid. // provided or invalid.
@ -261,12 +299,32 @@ func (r *Request) SetStringBody(s string) {
// SetReaderBody will set the request's body reader. // SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) { func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.Body = reader r.Body = reader
r.BodyStart, _ = reader.Seek(0, sdkio.SeekCurrent) // Get the Bodies current offset.
if aws.IsReaderSeekable(reader) {
var err error
// Get the Bodies current offset so retries will start from the same
// initial position.
r.BodyStart, err = reader.Seek(0, sdkio.SeekCurrent)
if err != nil {
r.Error = awserr.New(ErrCodeSerialization,
"failed to determine start of request body", err)
return
}
}
r.ResetBody() r.ResetBody()
} }
// SetStreamingBody set the reader to be used for the request that will stream
// bytes to the server. Request's Body must not be set to any reader.
func (r *Request) SetStreamingBody(reader io.ReadCloser) {
r.streamingBody = reader
r.SetReaderBody(aws.ReadSeekCloser(reader))
}
// Presign returns the request's signed URL. Error will be returned // Presign returns the request's signed URL. Error will be returned
// if the signing fails. // if the signing fails. The expire parameter is only used for presigned Amazon
// S3 API requests. All other AWS services will use a fixed expiration
// time of 15 minutes.
// //
// It is invalid to create a presigned URL with a expire duration 0 or less. An // It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less. // error is returned if expire duration is 0 or less.
@ -283,7 +341,9 @@ func (r *Request) Presign(expire time.Duration) (string, error) {
} }
// PresignRequest behaves just like presign, with the addition of returning a // PresignRequest behaves just like presign, with the addition of returning a
// set of headers that were signed. // set of headers that were signed. The expire parameter is only used for
// presigned Amazon S3 API requests. All other AWS services will use a fixed
// expiration time of 15 minutes.
// //
// It is invalid to create a presigned URL with a expire duration 0 or less. An // It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less. // error is returned if expire duration is 0 or less.
@ -328,16 +388,15 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
} }
func debugLogReqError(r *Request, stage string, retrying bool, err error) { const (
notRetrying = "not retrying"
)
func debugLogReqError(r *Request, stage, retryStr string, err error) {
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) { if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return return
} }
retryStr := "not retrying"
if retrying {
retryStr = "will retry"
}
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v", r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err)) stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
} }
@ -356,12 +415,12 @@ func (r *Request) Build() error {
if !r.built { if !r.built {
r.Handlers.Validate.Run(r) r.Handlers.Validate.Run(r)
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Validate Request", false, r.Error) debugLogReqError(r, "Validate Request", notRetrying, r.Error)
return r.Error return r.Error
} }
r.Handlers.Build.Run(r) r.Handlers.Build.Run(r)
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error) debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error return r.Error
} }
r.built = true r.built = true
@ -377,20 +436,30 @@ func (r *Request) Build() error {
func (r *Request) Sign() error { func (r *Request) Sign() error {
r.Build() r.Build()
if r.Error != nil { if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error) debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error return r.Error
} }
SanitizeHostForHeader(r.HTTPRequest)
r.Handlers.Sign.Run(r) r.Handlers.Sign.Run(r)
return r.Error return r.Error
} }
func (r *Request) getNextRequestBody() (io.ReadCloser, error) { func (r *Request) getNextRequestBody() (body io.ReadCloser, err error) {
if r.streamingBody != nil {
return r.streamingBody, nil
}
if r.safeBody != nil { if r.safeBody != nil {
r.safeBody.Close() r.safeBody.Close()
} }
r.safeBody = newOffsetReader(r.Body, r.BodyStart) r.safeBody, err = newOffsetReader(r.Body, r.BodyStart)
if err != nil {
return nil, awserr.New(ErrCodeSerialization,
"failed to get next request body reader", err)
}
// Go 1.8 tightened and clarified the rules code needs to use when building // Go 1.8 tightened and clarified the rules code needs to use when building
// requests with the http package. Go 1.8 removed the automatic detection // requests with the http package. Go 1.8 removed the automatic detection
@ -407,10 +476,10 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
// Related golang/go#18257 // Related golang/go#18257
l, err := aws.SeekerLen(r.Body) l, err := aws.SeekerLen(r.Body)
if err != nil { if err != nil {
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err) return nil, awserr.New(ErrCodeSerialization,
"failed to compute request body size", err)
} }
var body io.ReadCloser
if l == 0 { if l == 0 {
body = NoBody body = NoBody
} else if l > 0 { } else if l > 0 {
@ -457,14 +526,50 @@ func (r *Request) GetBody() io.ReadSeeker {
// Send will not close the request.Request's body. // Send will not close the request.Request's body.
func (r *Request) Send() error { func (r *Request) Send() error {
defer func() { defer func() {
// Ensure a non-nil HTTPResponse parameter is set to ensure handlers
// checking for HTTPResponse values, don't fail.
if r.HTTPResponse == nil {
r.HTTPResponse = &http.Response{
Header: http.Header{},
Body: ioutil.NopCloser(&bytes.Buffer{}),
}
}
// Regardless of success or failure of the request trigger the Complete // Regardless of success or failure of the request trigger the Complete
// request handlers. // request handlers.
r.Handlers.Complete.Run(r) r.Handlers.Complete.Run(r)
}() }()
if err := r.Error; err != nil {
return err
}
for { for {
r.Error = nil
r.AttemptTime = time.Now() r.AttemptTime = time.Now()
if aws.BoolValue(r.Retryable) {
if err := r.Sign(); err != nil {
debugLogReqError(r, "Sign Request", notRetrying, err)
return err
}
if err := r.sendRequest(); err == nil {
return nil
}
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil || !aws.BoolValue(r.Retryable) {
return r.Error
}
if err := r.prepareRetry(); err != nil {
r.Error = err
return err
}
}
}
func (r *Request) prepareRetry() error {
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) { if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d", r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount)) r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
@ -475,68 +580,50 @@ func (r *Request) Send() error {
// the request's body even though the Client's Do returned. // the request's body even though the Client's Do returned.
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil) r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
r.ResetBody() r.ResetBody()
if err := r.Error; err != nil {
return awserr.New(ErrCodeSerialization,
"failed to prepare body for retry", err)
}
// Closing response body to ensure that no response body is leaked // Closing response body to ensure that no response body is leaked
// between retry attempts. // between retry attempts.
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil { if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
r.HTTPResponse.Body.Close() r.HTTPResponse.Body.Close()
} }
}
r.Sign() return nil
if r.Error != nil { }
return r.Error
} func (r *Request) sendRequest() (sendErr error) {
defer r.Handlers.CompleteAttempt.Run(r)
r.Retryable = nil r.Retryable = nil
r.Handlers.Send.Run(r) r.Handlers.Send.Run(r)
if r.Error != nil { if r.Error != nil {
if !shouldRetryCancel(r) { debugLogReqError(r, "Send Request",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error return r.Error
} }
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Send Request", false, err)
return r.Error
}
debugLogReqError(r, "Send Request", true, err)
continue
}
r.Handlers.UnmarshalMeta.Run(r) r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r) r.Handlers.ValidateResponse.Run(r)
if r.Error != nil { if r.Error != nil {
r.Handlers.UnmarshalError.Run(r) r.Handlers.UnmarshalError.Run(r)
err := r.Error debugLogReqError(r, "Validate Response",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Handlers.Retry.Run(r) r.Error)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Response", false, err)
return r.Error return r.Error
} }
debugLogReqError(r, "Validate Response", true, err)
continue
}
r.Handlers.Unmarshal.Run(r) r.Handlers.Unmarshal.Run(r)
if r.Error != nil { if r.Error != nil {
err := r.Error debugLogReqError(r, "Unmarshal Response",
r.Handlers.Retry.Run(r) fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Handlers.AfterRetry.Run(r) r.Error)
if r.Error != nil {
debugLogReqError(r, "Unmarshal Response", false, err)
return r.Error return r.Error
} }
debugLogReqError(r, "Unmarshal Response", true, err)
continue
}
break
}
return nil return nil
} }
@ -561,32 +648,6 @@ func AddToUserAgent(r *Request, s string) {
r.HTTPRequest.Header.Set("User-Agent", s) r.HTTPRequest.Header.Set("User-Agent", s)
} }
func shouldRetryCancel(r *Request) bool {
awsErr, ok := r.Error.(awserr.Error)
timeoutErr := false
errStr := r.Error.Error()
if ok {
if awsErr.Code() == CanceledErrorCode {
return false
}
err := awsErr.OrigErr()
netErr, netOK := err.(net.Error)
timeoutErr = netOK && netErr.Temporary()
if urlErr, ok := err.(*url.Error); !timeoutErr && ok {
errStr = urlErr.Err.Error()
}
}
// There can be two types of canceled errors here.
// The first being a net.Error and the other being an error.
// If the request was timed out, we want to continue the retry
// process. Otherwise, return the canceled error.
return timeoutErr ||
(errStr != "net/http: request canceled" &&
errStr != "net/http: request canceled while waiting for connection")
}
// SanitizeHostForHeader removes default port from host and updates request.Host // SanitizeHostForHeader removes default port from host and updates request.Host
func SanitizeHostForHeader(r *http.Request) { func SanitizeHostForHeader(r *http.Request) {
host := getHost(r) host := getHost(r)
@ -602,6 +663,10 @@ func getHost(r *http.Request) string {
return r.Host return r.Host
} }
if r.URL == nil {
return ""
}
return r.URL.Host return r.URL.Host
} }

View file

@ -1,3 +1,4 @@
//go:build !go1.8
// +build !go1.8 // +build !go1.8
package request package request

View file

@ -1,9 +1,12 @@
//go:build go1.8
// +build go1.8 // +build go1.8
package request package request
import ( import (
"net/http" "net/http"
"github.com/aws/aws-sdk-go/aws/awserr"
) )
// NoBody is a http.NoBody reader instructing Go HTTP client to not include // NoBody is a http.NoBody reader instructing Go HTTP client to not include
@ -24,7 +27,8 @@ var NoBody = http.NoBody
func (r *Request) ResetBody() { func (r *Request) ResetBody() {
body, err := r.getNextRequestBody() body, err := r.getNextRequestBody()
if err != nil { if err != nil {
r.Error = err r.Error = awserr.New(ErrCodeSerialization,
"failed to reset request body", err)
return return
} }

View file

@ -1,3 +1,4 @@
//go:build go1.7
// +build go1.7 // +build go1.7
package request package request

View file

@ -1,3 +1,4 @@
//go:build !go1.7
// +build !go1.7 // +build !go1.7
package request package request

View file

@ -17,11 +17,13 @@ import (
// does the pagination between API operations, and Paginator defines the // does the pagination between API operations, and Paginator defines the
// configuration that will be used per page request. // configuration that will be used per page request.
// //
// cont := true // for p.Next() {
// for p.Next() && cont {
// data := p.Page().(*s3.ListObjectsOutput) // data := p.Page().(*s3.ListObjectsOutput)
// // process the page's data // // process the page's data
// // ...
// // break out of loop to stop fetching additional pages
// } // }
//
// return p.Err() // return p.Err()
// //
// See service client API operation Pages methods for examples how the SDK will // See service client API operation Pages methods for examples how the SDK will
@ -146,7 +148,7 @@ func (r *Request) nextPageTokens() []interface{} {
return nil return nil
} }
case bool: case bool:
if v == false { if !v {
return nil return nil
} }
} }

View file

@ -1,32 +1,81 @@
package request package request
import ( import (
"net"
"net/url"
"strings"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
) )
// Retryer is an interface to control retry logic for a given service. // Retryer provides the interface drive the SDK's request retry behavior. The
// The default implementation used by most services is the client.DefaultRetryer // Retryer implementation is responsible for implementing exponential backoff,
// structure, which contains basic retry logic using exponential backoff. // and determine if a request API error should be retried.
//
// client.DefaultRetryer is the SDK's default implementation of the Retryer. It
// uses the Request.IsErrorRetryable and Request.IsErrorThrottle methods to
// determine if the request is retried.
type Retryer interface { type Retryer interface {
// RetryRules return the retry delay that should be used by the SDK before
// making another request attempt for the failed request.
RetryRules(*Request) time.Duration RetryRules(*Request) time.Duration
// ShouldRetry returns if the failed request is retryable.
//
// Implementations may consider request attempt count when determining if a
// request is retryable, but the SDK will use MaxRetries to limit the
// number of attempts a request are made.
ShouldRetry(*Request) bool ShouldRetry(*Request) bool
// MaxRetries is the number of times a request may be retried before
// failing.
MaxRetries() int MaxRetries() int
} }
// WithRetryer sets a config Retryer value to the given Config returning it // WithRetryer sets a Retryer value to the given Config returning the Config
// for chaining. // value for chaining. The value must not be nil.
func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config { func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config {
if retryer == nil {
if cfg.Logger != nil {
cfg.Logger.Log("ERROR: Request.WithRetryer called with nil retryer. Replacing with retry disabled Retryer.")
}
retryer = noOpRetryer{}
}
cfg.Retryer = retryer cfg.Retryer = retryer
return cfg return cfg
}
// noOpRetryer is a internal no op retryer used when a request is created
// without a retryer.
//
// Provides a retryer that performs no retries.
// It should be used when we do not want retries to be performed.
type noOpRetryer struct{}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
func (d noOpRetryer) MaxRetries() int {
return 0
}
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
func (d noOpRetryer) ShouldRetry(_ *Request) bool {
return false
}
// RetryRules returns the delay duration before retrying this request again;
// since NoOpRetryer does not retry, RetryRules always returns 0.
func (d noOpRetryer) RetryRules(_ *Request) time.Duration {
return 0
} }
// retryableCodes is a collection of service response codes which are retry-able // retryableCodes is a collection of service response codes which are retry-able
// without any further action. // without any further action.
var retryableCodes = map[string]struct{}{ var retryableCodes = map[string]struct{}{
"RequestError": {}, ErrCodeRequestError: {},
"RequestTimeout": {}, "RequestTimeout": {},
ErrCodeResponseTimeout: {}, ErrCodeResponseTimeout: {},
"RequestTimeoutException": {}, // Glacier's flavor of RequestTimeout "RequestTimeoutException": {}, // Glacier's flavor of RequestTimeout
@ -34,12 +83,16 @@ var retryableCodes = map[string]struct{}{
var throttleCodes = map[string]struct{}{ var throttleCodes = map[string]struct{}{
"ProvisionedThroughputExceededException": {}, "ProvisionedThroughputExceededException": {},
"ThrottledException": {}, // SNS, XRay, ResourceGroupsTagging API
"Throttling": {}, "Throttling": {},
"ThrottlingException": {}, "ThrottlingException": {},
"RequestLimitExceeded": {}, "RequestLimitExceeded": {},
"RequestThrottled": {}, "RequestThrottled": {},
"RequestThrottledException": {},
"TooManyRequestsException": {}, // Lambda functions "TooManyRequestsException": {}, // Lambda functions
"PriorRequestNotComplete": {}, // Route53 "PriorRequestNotComplete": {}, // Route53
"TransactionInProgressException": {},
"EC2ThrottledException": {}, // EC2
} }
// credsExpiredCodes is a collection of error codes which signify the credentials // credsExpiredCodes is a collection of error codes which signify the credentials
@ -74,10 +127,6 @@ var validParentCodes = map[string]struct{}{
ErrCodeRead: {}, ErrCodeRead: {},
} }
type temporaryError interface {
Temporary() bool
}
func isNestedErrorRetryable(parentErr awserr.Error) bool { func isNestedErrorRetryable(parentErr awserr.Error) bool {
if parentErr == nil { if parentErr == nil {
return false return false
@ -96,7 +145,7 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
return isCodeRetryable(aerr.Code()) return isCodeRetryable(aerr.Code())
} }
if t, ok := err.(temporaryError); ok { if t, ok := err.(temporary); ok {
return t.Temporary() || isErrConnectionReset(err) return t.Temporary() || isErrConnectionReset(err)
} }
@ -106,33 +155,91 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
// IsErrorRetryable returns whether the error is retryable, based on its Code. // IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if error is nil. // Returns false if error is nil.
func IsErrorRetryable(err error) bool { func IsErrorRetryable(err error) bool {
if err != nil { if err == nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr)
}
}
return false return false
}
return shouldRetryError(err)
}
type temporary interface {
Temporary() bool
}
func shouldRetryError(origErr error) bool {
switch err := origErr.(type) {
case awserr.Error:
if err.Code() == CanceledErrorCode {
return false
}
if isNestedErrorRetryable(err) {
return true
}
origErr := err.OrigErr()
var shouldRetry bool
if origErr != nil {
shouldRetry = shouldRetryError(origErr)
if err.Code() == ErrCodeRequestError && !shouldRetry {
return false
}
}
if isCodeRetryable(err.Code()) {
return true
}
return shouldRetry
case *url.Error:
if strings.Contains(err.Error(), "connection refused") {
// Refused connections should be retried as the service may not yet
// be running on the port. Go TCP dial considers refused
// connections as not temporary.
return true
}
// *url.Error only implements Temporary after golang 1.6 but since
// url.Error only wraps the error:
return shouldRetryError(err.Err)
case temporary:
if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" {
return true
}
// If the error is temporary, we want to allow continuation of the
// retry process
return err.Temporary() || isErrConnectionReset(origErr)
case nil:
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
// because we don't know the cause, it is marked as retryable. See
// TestRequest4xxUnretryable for an example.
return true
default:
switch err.Error() {
case "net/http: request canceled",
"net/http: request canceled while waiting for connection":
// known 1.5 error case when an http request is cancelled
return false
}
// here we don't know the error; so we allow a retry.
return true
}
} }
// IsErrorThrottle returns whether the error is to be throttled based on its code. // IsErrorThrottle returns whether the error is to be throttled based on its code.
// Returns false if error is nil. // Returns false if error is nil.
func IsErrorThrottle(err error) bool { func IsErrorThrottle(err error) bool {
if err != nil { if aerr, ok := err.(awserr.Error); ok && aerr != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeThrottle(aerr.Code()) return isCodeThrottle(aerr.Code())
} }
}
return false return false
} }
// IsErrorExpiredCreds returns whether the error code is a credential expiry error. // IsErrorExpiredCreds returns whether the error code is a credential expiry
// Returns false if error is nil. // error. Returns false if error is nil.
func IsErrorExpiredCreds(err error) bool { func IsErrorExpiredCreds(err error) bool {
if err != nil { if aerr, ok := err.(awserr.Error); ok && aerr != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeExpiredCreds(aerr.Code()) return isCodeExpiredCreds(aerr.Code())
} }
}
return false return false
} }
@ -141,17 +248,58 @@ func IsErrorExpiredCreds(err error) bool {
// //
// Alias for the utility function IsErrorRetryable // Alias for the utility function IsErrorRetryable
func (r *Request) IsErrorRetryable() bool { func (r *Request) IsErrorRetryable() bool {
if isErrCode(r.Error, r.RetryErrorCodes) {
return true
}
// HTTP response status code 501 should not be retried.
// 501 represents Not Implemented which means the request method is not
// supported by the server and cannot be handled.
if r.HTTPResponse != nil {
// HTTP response status code 500 represents internal server error and
// should be retried without any throttle.
if r.HTTPResponse.StatusCode == 500 {
return true
}
}
return IsErrorRetryable(r.Error) return IsErrorRetryable(r.Error)
} }
// IsErrorThrottle returns whether the error is to be throttled based on its code. // IsErrorThrottle returns whether the error is to be throttled based on its
// Returns false if the request has no Error set // code. Returns false if the request has no Error set.
// //
// Alias for the utility function IsErrorThrottle // Alias for the utility function IsErrorThrottle
func (r *Request) IsErrorThrottle() bool { func (r *Request) IsErrorThrottle() bool {
if isErrCode(r.Error, r.ThrottleErrorCodes) {
return true
}
if r.HTTPResponse != nil {
switch r.HTTPResponse.StatusCode {
case
429, // error caused due to too many requests
502, // Bad Gateway error should be throttled
503, // caused when service is unavailable
504: // error occurred due to gateway timeout
return true
}
}
return IsErrorThrottle(r.Error) return IsErrorThrottle(r.Error)
} }
func isErrCode(err error, codes []string) bool {
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
for _, code := range codes {
if code == aerr.Code() {
return true
}
}
}
return false
}
// IsErrorExpired returns whether the error code is a credential expiry error. // IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set. // Returns false if the request has no Error set.
// //

View file

@ -17,6 +17,12 @@ const (
ParamMinValueErrCode = "ParamMinValueError" ParamMinValueErrCode = "ParamMinValueError"
// ParamMinLenErrCode is the error code for fields without enough elements. // ParamMinLenErrCode is the error code for fields without enough elements.
ParamMinLenErrCode = "ParamMinLenError" ParamMinLenErrCode = "ParamMinLenError"
// ParamMaxLenErrCode is the error code for value being too long.
ParamMaxLenErrCode = "ParamMaxLenError"
// ParamFormatErrCode is the error code for a field with invalid
// format or characters.
ParamFormatErrCode = "ParamFormatInvalidError"
) )
// Validator provides a way for types to perform validation logic on their // Validator provides a way for types to perform validation logic on their
@ -232,3 +238,49 @@ func NewErrParamMinLen(field string, min int) *ErrParamMinLen {
func (e *ErrParamMinLen) MinLen() int { func (e *ErrParamMinLen) MinLen() int {
return e.min return e.min
} }
// An ErrParamMaxLen represents a maximum length parameter error.
type ErrParamMaxLen struct {
errInvalidParam
max int
}
// NewErrParamMaxLen creates a new maximum length parameter error.
func NewErrParamMaxLen(field string, max int, value string) *ErrParamMaxLen {
return &ErrParamMaxLen{
errInvalidParam: errInvalidParam{
code: ParamMaxLenErrCode,
field: field,
msg: fmt.Sprintf("maximum size of %v, %v", max, value),
},
max: max,
}
}
// MaxLen returns the field's required minimum length.
func (e *ErrParamMaxLen) MaxLen() int {
return e.max
}
// An ErrParamFormat represents a invalid format parameter error.
type ErrParamFormat struct {
errInvalidParam
format string
}
// NewErrParamFormat creates a new invalid format parameter error.
func NewErrParamFormat(field string, format, value string) *ErrParamFormat {
return &ErrParamFormat{
errInvalidParam: errInvalidParam{
code: ParamFormatErrCode,
field: field,
msg: fmt.Sprintf("format %v, %v", format, value),
},
format: format,
}
}
// Format returns the field's required format.
func (e *ErrParamFormat) Format() string {
return e.format
}

View file

@ -0,0 +1,303 @@
package session
import (
"fmt"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/ssocreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/aws/aws-sdk-go/service/sts"
)
// CredentialsProviderOptions specifies additional options for configuring
// credentials providers.
type CredentialsProviderOptions struct {
// WebIdentityRoleProviderOptions configures a WebIdentityRoleProvider,
// such as setting its ExpiryWindow.
WebIdentityRoleProviderOptions func(*stscreds.WebIdentityRoleProvider)
}
func resolveCredentials(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (*credentials.Credentials, error) {
switch {
case len(sessOpts.Profile) != 0:
// User explicitly provided an Profile in the session's configuration
// so load that profile from shared config first.
// Github(aws/aws-sdk-go#2727)
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
case envCfg.Creds.HasKeys():
// Environment credentials
return credentials.NewStaticCredentialsFromCreds(envCfg.Creds), nil
case len(envCfg.WebIdentityTokenFilePath) != 0:
// Web identity token from environment, RoleARN required to also be
// set.
return assumeWebIdentity(cfg, handlers,
envCfg.WebIdentityTokenFilePath,
envCfg.RoleARN,
envCfg.RoleSessionName,
sessOpts.CredentialsProviderOptions,
)
default:
// Fallback to the "default" credential resolution chain.
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
}
}
// WebIdentityEmptyRoleARNErr will occur if 'AWS_WEB_IDENTITY_TOKEN_FILE' was set but
// 'AWS_ROLE_ARN' was not set.
var WebIdentityEmptyRoleARNErr = awserr.New(stscreds.ErrCodeWebIdentity, "role ARN is not set", nil)
// WebIdentityEmptyTokenFilePathErr will occur if 'AWS_ROLE_ARN' was set but
// 'AWS_WEB_IDENTITY_TOKEN_FILE' was not set.
var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "token file path is not set", nil)
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
filepath string,
roleARN, sessionName string,
credOptions *CredentialsProviderOptions,
) (*credentials.Credentials, error) {
if len(filepath) == 0 {
return nil, WebIdentityEmptyTokenFilePathErr
}
if len(roleARN) == 0 {
return nil, WebIdentityEmptyRoleARNErr
}
svc := sts.New(&Session{
Config: cfg,
Handlers: handlers.Copy(),
})
var optFns []func(*stscreds.WebIdentityRoleProvider)
if credOptions != nil && credOptions.WebIdentityRoleProviderOptions != nil {
optFns = append(optFns, credOptions.WebIdentityRoleProviderOptions)
}
p := stscreds.NewWebIdentityRoleProviderWithOptions(svc, roleARN, sessionName, stscreds.FetchTokenPath(filepath), optFns...)
return credentials.NewCredentials(p), nil
}
func resolveCredsFromProfile(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch {
case sharedCfg.SourceProfile != nil:
// Assume IAM role with credentials source from a different profile.
creds, err = resolveCredsFromProfile(cfg, envCfg,
*sharedCfg.SourceProfile, handlers, sessOpts,
)
case sharedCfg.Creds.HasKeys():
// Static Credentials from Shared Config/Credentials file.
creds = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
case len(sharedCfg.CredentialSource) != 0:
creds, err = resolveCredsFromSource(cfg, envCfg,
sharedCfg, handlers, sessOpts,
)
case len(sharedCfg.WebIdentityTokenFile) != 0:
// Credentials from Assume Web Identity token require an IAM Role, and
// that roll will be assumed. May be wrapped with another assume role
// via SourceProfile.
return assumeWebIdentity(cfg, handlers,
sharedCfg.WebIdentityTokenFile,
sharedCfg.RoleARN,
sharedCfg.RoleSessionName,
sessOpts.CredentialsProviderOptions,
)
case sharedCfg.hasSSOConfiguration():
creds, err = resolveSSOCredentials(cfg, sharedCfg, handlers)
case len(sharedCfg.CredentialProcess) != 0:
// Get credentials from CredentialProcess
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
default:
// Fallback to default credentials provider, include mock errors for
// the credential chain so user can identify why credentials failed to
// be retrieved.
creds = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{
Err: awserr.New("EnvAccessKeyNotFound",
"failed to find credentials in the environment.", nil),
},
&credProviderError{
Err: awserr.New("SharedCredsLoad",
fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil),
},
defaults.RemoteCredProvider(*cfg, handlers),
},
})
}
if err != nil {
return nil, err
}
if len(sharedCfg.RoleARN) > 0 {
cfgCp := *cfg
cfgCp.Credentials = creds
return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts)
}
return creds, nil
}
func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers request.Handlers) (*credentials.Credentials, error) {
if err := sharedCfg.validateSSOConfiguration(); err != nil {
return nil, err
}
cfgCopy := cfg.Copy()
cfgCopy.Region = &sharedCfg.SSORegion
return ssocreds.NewCredentials(
&Session{
Config: cfgCopy,
Handlers: handlers.Copy(),
},
sharedCfg.SSOAccountID,
sharedCfg.SSORoleName,
sharedCfg.SSOStartURL,
), nil
}
// valid credential source values
const (
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
)
func resolveCredsFromSource(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch sharedCfg.CredentialSource {
case credSourceEc2Metadata:
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
case credSourceEnvironment:
creds = credentials.NewStaticCredentialsFromCreds(envCfg.Creds)
case credSourceECSContainer:
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
return nil, ErrSharedConfigECSContainerEnvVarEmpty
}
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
default:
return nil, ErrSharedConfigInvalidCredSource
}
return creds, nil
}
func credsFromAssumeRole(cfg aws.Config,
handlers request.Handlers,
sharedCfg sharedConfig,
sessOpts Options,
) (*credentials.Credentials, error) {
if len(sharedCfg.MFASerial) != 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return nil, AssumeRoleTokenProviderNotSetError{}
}
return stscreds.NewCredentials(
&Session{
Config: &cfg,
Handlers: handlers.Copy(),
},
sharedCfg.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.RoleSessionName
if sessOpts.AssumeRoleDuration == 0 &&
sharedCfg.AssumeRoleDuration != nil &&
*sharedCfg.AssumeRoleDuration/time.Minute > 15 {
opt.Duration = *sharedCfg.AssumeRoleDuration
} else if sessOpts.AssumeRoleDuration != 0 {
opt.Duration = sessOpts.AssumeRoleDuration
}
// Assume role with external ID
if len(sharedCfg.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.ExternalID)
}
// Assume role with MFA
if len(sharedCfg.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
), nil
}
// AssumeRoleTokenProviderNotSetError is an error returned when creating a
// session when the MFAToken option is not set when shared config is configured
// load assume a role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
}

View file

@ -0,0 +1,28 @@
//go:build go1.13
// +build go1.13
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCustomTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View file

@ -0,0 +1,27 @@
//go:build !go1.13 && go1.7
// +build !go1.13,go1.7
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCustomTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View file

@ -0,0 +1,23 @@
//go:build !go1.6 && go1.5
// +build !go1.6,go1.5
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCustomTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
}
}

View file

@ -0,0 +1,24 @@
//go:build !go1.7 && go1.6
// +build !go1.7,go1.6
package session
import (
"net"
"net/http"
"time"
)
// Transport that should be used when a custom CA bundle is specified with the
// SDK.
func getCustomTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View file

@ -1,97 +1,93 @@
/* /*
Package session provides configuration for the SDK's service clients. Package session provides configuration for the SDK's service clients. Sessions
can be shared across service clients that share the same base configuration.
Sessions can be shared across all service clients that share the same base
configuration. The Session is built from the SDK's default configuration and
request handlers.
Sessions should be cached when possible, because creating a new Session will
load all configuration values from the environment, and config files each time
the Session is created. Sharing the Session value across all of your service
clients will ensure the configuration is loaded the fewest number of times possible.
Concurrency
Sessions are safe to use concurrently as long as the Session is not being Sessions are safe to use concurrently as long as the Session is not being
modified. The SDK will not modify the Session once the Session has been created. modified. Sessions should be cached when possible, because creating a new
Creating service clients concurrently from a shared Session is safe. Session will load all configuration values from the environment, and config
files each time the Session is created. Sharing the Session value across all of
your service clients will ensure the configuration is loaded the fewest number
of times possible.
Sessions from Shared Config Sessions options from Shared Config
Sessions can be created using the method above that will only load the
additional config if the AWS_SDK_LOAD_CONFIG environment variable is set.
Alternatively you can explicitly create a Session with shared config enabled.
To do this you can use NewSessionWithOptions to configure how the Session will
be created. Using the NewSessionWithOptions with SharedConfigState set to
SharedConfigEnable will create the session as if the AWS_SDK_LOAD_CONFIG
environment variable was set.
Creating Sessions
When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
By default NewSession will only load credentials from the shared credentials By default NewSession will only load credentials from the shared credentials
file (~/.aws/credentials). If the AWS_SDK_LOAD_CONFIG environment variable is file (~/.aws/credentials). If the AWS_SDK_LOAD_CONFIG environment variable is
set to a truthy value the Session will be created from the configuration set to a truthy value the Session will be created from the configuration
values from the shared config (~/.aws/config) and shared credentials values from the shared config (~/.aws/config) and shared credentials
(~/.aws/credentials) files. See the section Sessions from Shared Config for (~/.aws/credentials) files. Using the NewSessionWithOptions with
more information. SharedConfigState set to SharedConfigEnable will create the session as if the
AWS_SDK_LOAD_CONFIG environment variable was set.
Create a Session with the default config and request handlers. With credentials Credential and config loading order
region, and profile loaded from the environment and shared config automatically.
Requires the AWS_PROFILE to be set, or "default" is used. The Session will attempt to load configuration and credentials from the
environment, configuration files, and other credential sources. The order
configuration is loaded in is:
* Environment Variables
* Shared Credentials file
* Shared Configuration file (if SharedConfig is enabled)
* EC2 Instance Metadata (credentials only)
The Environment variables for credentials will have precedence over shared
config even if SharedConfig is enabled. To override this behavior, and use
shared config credentials instead specify the session.Options.Profile, (e.g.
when using credential_source=Environment to assume a role).
sess, err := session.NewSessionWithOptions(session.Options{
Profile: "myProfile",
})
Creating Sessions
Creating a Session without additional options will load credentials region, and
profile loaded from the environment and shared config automatically. See,
"Environment Variables" section for information on environment variables used
by Session.
// Create Session // Create Session
sess := session.Must(session.NewSession()) sess, err := session.NewSession()
When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded, config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
// Create a Session with a custom region // Create a Session with a custom region
sess := session.Must(session.NewSession(&aws.Config{ sess, err := session.NewSession(&aws.Config{
Region: aws.String("us-east-1"), Region: aws.String("us-west-2"),
})) })
// Create a S3 client instance from a session Use NewSessionWithOptions to provide additional configuration driving how the
sess := session.Must(session.NewSession()) Session's configuration will be loaded. Such as, specifying shared config
profile, or override the shared config state, (AWS_SDK_LOAD_CONFIG).
svc := s3.New(sess)
Create Session With Option Overrides
In addition to NewSession, Sessions can be created using NewSessionWithOptions.
This func allows you to control and override how the Session will be created
through code instead of being driven by environment variables only.
Use NewSessionWithOptions when you want to provide the config profile, or
override the shared config state (AWS_SDK_LOAD_CONFIG).
// Equivalent to session.NewSession() // Equivalent to session.NewSession()
sess := session.Must(session.NewSessionWithOptions(session.Options{ sess, err := session.NewSessionWithOptions(session.Options{
// Options // Options
})) })
sess, err := session.NewSessionWithOptions(session.Options{
// Specify profile to load for the session's config // Specify profile to load for the session's config
sess := session.Must(session.NewSessionWithOptions(session.Options{
Profile: "profile_name", Profile: "profile_name",
}))
// Specify profile for config and region for requests // Provide SDK Config options, such as Region.
sess := session.Must(session.NewSessionWithOptions(session.Options{ Config: aws.Config{
Config: aws.Config{Region: aws.String("us-east-1")}, Region: aws.String("us-west-2"),
Profile: "profile_name", },
}))
// Force enable Shared Config support // Force enable Shared Config support
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable, SharedConfigState: session.SharedConfigEnable,
})) })
Adding Handlers Adding Handlers
You can add handlers to a session for processing HTTP requests. All service You can add handlers to a session to decorate API operation, (e.g. adding HTTP
clients that use the session inherit the handlers. For example, the following headers). All clients that use the Session receive a copy of the Session's
handler logs every request and its payload made by a service client: handlers. For example, the following request handler added to the Session logs
every requests made.
// Create a session, and add additional handlers for all service // Create a session, and add additional handlers for all service
// clients created with the Session to inherit. Adds logging handler. // clients created with the Session to inherit. Adds logging handler.
@ -99,22 +95,15 @@ handler logs every request and its payload made by a service client:
sess.Handlers.Send.PushFront(func(r *request.Request) { sess.Handlers.Send.PushFront(func(r *request.Request) {
// Log every request made and its payload // Log every request made and its payload
logger.Println("Request: %s/%s, Payload: %s", logger.Printf("Request: %s/%s, Params: %s",
r.ClientInfo.ServiceName, r.Operation, r.Params) r.ClientInfo.ServiceName, r.Operation, r.Params)
}) })
Deprecated "New" function
The New session function has been deprecated because it does not provide good
way to return errors that occur when loading the configuration files and values.
Because of this, NewSession was created so errors can be retrieved when
creating a session fails.
Shared Config Fields Shared Config Fields
By default the SDK will only load the shared credentials file's (~/.aws/credentials) By default the SDK will only load the shared credentials file's
credentials values, and all other config is provided by the environment variables, (~/.aws/credentials) credentials values, and all other config is provided by
SDK defaults, and user provided aws.Config values. the environment variables, SDK defaults, and user provided aws.Config values.
If the AWS_SDK_LOAD_CONFIG environment variable is set, or SharedConfigEnable If the AWS_SDK_LOAD_CONFIG environment variable is set, or SharedConfigEnable
option is used to create the Session the full shared config values will be option is used to create the Session the full shared config values will be
@ -125,24 +114,31 @@ files have the same format.
If both config files are present the configuration from both files will be If both config files are present the configuration from both files will be
read. The Session will be created from configuration values from the shared read. The Session will be created from configuration values from the shared
credentials file (~/.aws/credentials) over those in the shared config file (~/.aws/config). credentials file (~/.aws/credentials) over those in the shared config file
(~/.aws/config).
Credentials are the values the SDK should use for authenticating requests with Credentials are the values the SDK uses to authenticating requests with AWS
AWS Services. They are from a configuration file will need to include both Services. When specified in a file, both aws_access_key_id and
aws_access_key_id and aws_secret_access_key must be provided together in the aws_secret_access_key must be provided together in the same file to be
same file to be considered valid. The values will be ignored if not a complete considered valid. They will be ignored if both are not present.
group. aws_session_token is an optional field that can be provided if both of aws_session_token is an optional field that can be provided in addition to the
the other two fields are also provided. other two fields.
aws_access_key_id = AKID aws_access_key_id = AKID
aws_secret_access_key = SECRET aws_secret_access_key = SECRET
aws_session_token = TOKEN aws_session_token = TOKEN
Assume Role values allow you to configure the SDK to assume an IAM role using ; region only supported if SharedConfigEnabled.
a set of credentials provided in a config file via the source_profile field. region = us-east-1
Both "role_arn" and "source_profile" are required. The SDK supports assuming
a role with MFA token if the session option AssumeRoleTokenProvider Assume Role configuration
is set.
The role_arn field allows you to configure the SDK to assume an IAM role using
a set of credentials from another source. Such as when paired with static
credentials, "profile_source", "credential_process", or "credential_source"
fields. If "role_arn" is provided, a source of credentials must also be
specified, such as "source_profile", "credential_source", or
"credential_process".
role_arn = arn:aws:iam::<account_number>:role/<role_name> role_arn = arn:aws:iam::<account_number>:role/<role_name>
source_profile = profile_with_creds source_profile = profile_with_creds
@ -150,40 +146,16 @@ is set.
mfa_serial = <serial or mfa arn> mfa_serial = <serial or mfa arn>
role_session_name = session_name role_session_name = session_name
Region is the region the SDK should use for looking up AWS service endpoints
and signing requests.
region = us-east-1 The SDK supports assuming a role with MFA token. If "mfa_serial" is set, you
must also set the Session Option.AssumeRoleTokenProvider. The Session will fail
Assume Role with MFA token to load if the AssumeRoleTokenProvider is not specified.
To create a session with support for assuming an IAM role with MFA set the
session option AssumeRoleTokenProvider to a function that will prompt for the
MFA token code when the SDK assumes the role and refreshes the role's credentials.
This allows you to configure the SDK via the shared config to assumea role
with MFA tokens.
In order for the SDK to assume a role with MFA the SharedConfigState
session option must be set to SharedConfigEnable, or AWS_SDK_LOAD_CONFIG
environment variable set.
The shared configuration instructs the SDK to assume an IAM role with MFA
when the mfa_serial configuration field is set in the shared config
(~/.aws/config) or shared credentials (~/.aws/credentials) file.
If mfa_serial is set in the configuration, the SDK will assume the role, and
the AssumeRoleTokenProvider session option is not set an an error will
be returned when creating the session.
sess := session.Must(session.NewSessionWithOptions(session.Options{ sess := session.Must(session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: stscreds.StdinTokenProvider, AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
})) }))
// Create service client value configured for credentials To setup Assume Role outside of a session see the stscreds.AssumeRoleProvider
// from assumed role.
svc := s3.New(sess)
To setup assume role outside of a session see the stscrds.AssumeRoleProvider
documentation. documentation.
Environment Variables Environment Variables
@ -236,6 +208,8 @@ env values as well.
AWS_SDK_LOAD_CONFIG=1 AWS_SDK_LOAD_CONFIG=1
Custom Shared Config and Credential Files
Shared credentials file path can be set to instruct the SDK to use an alternative Shared credentials file path can be set to instruct the SDK to use an alternative
file for the shared credentials. If not set the file will be loaded from file for the shared credentials. If not set the file will be loaded from
$HOME/.aws/credentials on Linux/Unix based systems, and $HOME/.aws/credentials on Linux/Unix based systems, and
@ -250,6 +224,8 @@ $HOME/.aws/config on Linux/Unix based systems, and
AWS_CONFIG_FILE=$HOME/my_shared_config AWS_CONFIG_FILE=$HOME/my_shared_config
Custom CA Bundle
Path to a custom Credentials Authority (CA) bundle PEM file that the SDK Path to a custom Credentials Authority (CA) bundle PEM file that the SDK
will use instead of the default system's root CA bundle. Use this only will use instead of the default system's root CA bundle. Use this only
if you want to replace the CA bundle the SDK uses for TLS requests. if you want to replace the CA bundle the SDK uses for TLS requests.
@ -269,5 +245,123 @@ over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
Setting a custom HTTPClient in the aws.Config options will override this setting. Setting a custom HTTPClient in the aws.Config options will override this setting.
To use this option and custom HTTP client, the HTTP client needs to be provided To use this option and custom HTTP client, the HTTP client needs to be provided
when creating the session. Not the service client. when creating the session. Not the service client.
Custom Client TLS Certificate
The SDK supports the environment and session option being configured with
Client TLS certificates that are sent as a part of the client's TLS handshake
for client authentication. If used, both Cert and Key values are required. If
one is missing, or either fail to load the contents of the file an error will
be returned.
HTTP Client's Transport concrete implementation must be a http.Transport
or creating the session will fail.
AWS_SDK_GO_CLIENT_TLS_KEY=$HOME/my_client_key
AWS_SDK_GO_CLIENT_TLS_CERT=$HOME/my_client_cert
This can also be configured via the session.Options ClientTLSCert and ClientTLSKey.
sess, err := session.NewSessionWithOptions(session.Options{
ClientTLSCert: myCertFile,
ClientTLSKey: myKeyFile,
})
Custom EC2 IMDS Endpoint
The endpoint of the EC2 IMDS client can be configured via the environment
variable, AWS_EC2_METADATA_SERVICE_ENDPOINT when creating the client with a
Session. See Options.EC2IMDSEndpoint for more details.
AWS_EC2_METADATA_SERVICE_ENDPOINT=http://169.254.169.254
If using an URL with an IPv6 address literal, the IPv6 address
component must be enclosed in square brackets.
AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1]
The custom EC2 IMDS endpoint can also be specified via the Session options.
sess, err := session.NewSessionWithOptions(session.Options{
EC2MetadataEndpoint: "http://[::1]",
})
FIPS and DualStack Endpoints
The SDK can be configured to resolve an endpoint with certain capabilities such as FIPS and DualStack.
You can configure a FIPS endpoint using an environment variable, shared config ($HOME/.aws/config),
or programmatically.
To configure a FIPS endpoint set the environment variable set the AWS_USE_FIPS_ENDPOINT to true or false to enable
or disable FIPS endpoint resolution.
AWS_USE_FIPS_ENDPOINT=true
To configure a FIPS endpoint using shared config, set use_fips_endpoint to true or false to enable
or disable FIPS endpoint resolution.
[profile myprofile]
region=us-west-2
use_fips_endpoint=true
To configure a FIPS endpoint programmatically
// Option 1: Configure it on a session for all clients
sess, err := session.NewSessionWithOptions(session.Options{
UseFIPSEndpoint: endpoints.FIPSEndpointStateEnabled,
})
if err != nil {
// handle error
}
client := s3.New(sess)
// Option 2: Configure it per client
sess, err := session.NewSession()
if err != nil {
// handle error
}
client := s3.New(sess, &aws.Config{
UseFIPSEndpoint: endpoints.FIPSEndpointStateEnabled,
})
You can configure a DualStack endpoint using an environment variable, shared config ($HOME/.aws/config),
or programmatically.
To configure a DualStack endpoint set the environment variable set the AWS_USE_DUALSTACK_ENDPOINT to true or false to
enable or disable DualStack endpoint resolution.
AWS_USE_DUALSTACK_ENDPOINT=true
To configure a DualStack endpoint using shared config, set use_dualstack_endpoint to true or false to enable
or disable DualStack endpoint resolution.
[profile myprofile]
region=us-west-2
use_dualstack_endpoint=true
To configure a DualStack endpoint programmatically
// Option 1: Configure it on a session for all clients
sess, err := session.NewSessionWithOptions(session.Options{
UseDualStackEndpoint: endpoints.DualStackEndpointStateEnabled,
})
if err != nil {
// handle error
}
client := s3.New(sess)
// Option 2: Configure it per client
sess, err := session.NewSession()
if err != nil {
// handle error
}
client := s3.New(sess, &aws.Config{
UseDualStackEndpoint: endpoints.DualStackEndpointStateEnabled,
})
*/ */
package session package session

View file

@ -1,11 +1,15 @@
package session package session
import ( import (
"fmt"
"os" "os"
"strconv" "strconv"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
) )
// EnvProviderName provides a name of the provider when config is loaded from environment. // EnvProviderName provides a name of the provider when config is loaded from environment.
@ -79,7 +83,7 @@ type envConfig struct {
// AWS_CONFIG_FILE=$HOME/my_shared_config // AWS_CONFIG_FILE=$HOME/my_shared_config
SharedConfigFile string SharedConfigFile string
// Sets the path to a custom Credentials Authroity (CA) Bundle PEM file // Sets the path to a custom Credentials Authority (CA) Bundle PEM file
// that the SDK will use instead of the system's root CA bundle. // that the SDK will use instead of the system's root CA bundle.
// Only use this if you want to configure the SDK to use a custom set // Only use this if you want to configure the SDK to use a custom set
// of CAs. // of CAs.
@ -97,16 +101,96 @@ type envConfig struct {
// AWS_CA_BUNDLE=$HOME/my_custom_ca_bundle // AWS_CA_BUNDLE=$HOME/my_custom_ca_bundle
CustomCABundle string CustomCABundle string
// Sets the TLC client certificate that should be used by the SDK's HTTP transport
// when making requests. The certificate must be paired with a TLS client key file.
//
// AWS_SDK_GO_CLIENT_TLS_CERT=$HOME/my_client_cert
ClientTLSCert string
// Sets the TLC client key that should be used by the SDK's HTTP transport
// when making requests. The key must be paired with a TLS client certificate file.
//
// AWS_SDK_GO_CLIENT_TLS_KEY=$HOME/my_client_key
ClientTLSKey string
csmEnabled string csmEnabled string
CSMEnabled bool CSMEnabled *bool
CSMPort string CSMPort string
CSMHost string
CSMClientID string CSMClientID string
// Enables endpoint discovery via environment variables.
//
// AWS_ENABLE_ENDPOINT_DISCOVERY=true
EnableEndpointDiscovery *bool
enableEndpointDiscovery string
// Specifies the WebIdentity token the SDK should use to assume a role
// with.
//
// AWS_WEB_IDENTITY_TOKEN_FILE=file_path
WebIdentityTokenFilePath string
// Specifies the IAM role arn to use when assuming an role.
//
// AWS_ROLE_ARN=role_arn
RoleARN string
// Specifies the IAM role session name to use when assuming a role.
//
// AWS_ROLE_SESSION_NAME=session_name
RoleSessionName string
// Specifies the STS Regional Endpoint flag for the SDK to resolve the endpoint
// for a service.
//
// AWS_STS_REGIONAL_ENDPOINTS=regional
// This can take value as `regional` or `legacy`
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// Specifies the S3 Regional Endpoint flag for the SDK to resolve the
// endpoint for a service.
//
// AWS_S3_US_EAST_1_REGIONAL_ENDPOINT=regional
// This can take value as `regional` or `legacy`
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
// Specifies if the S3 service should allow ARNs to direct the region
// the client's requests are sent to.
//
// AWS_S3_USE_ARN_REGION=true
S3UseARNRegion bool
// Specifies the EC2 Instance Metadata Service endpoint to use. If specified it overrides EC2IMDSEndpointMode.
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1]
EC2IMDSEndpoint string
// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6
EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState
// Specifies that SDK clients must resolve a dual-stack endpoint for
// services.
//
// AWS_USE_DUALSTACK_ENDPOINT=true
UseDualStackEndpoint endpoints.DualStackEndpointState
// Specifies that SDK clients must resolve a FIPS endpoint for
// services.
//
// AWS_USE_FIPS_ENDPOINT=true
UseFIPSEndpoint endpoints.FIPSEndpointState
} }
var ( var (
csmEnabledEnvKey = []string{ csmEnabledEnvKey = []string{
"AWS_CSM_ENABLED", "AWS_CSM_ENABLED",
} }
csmHostEnvKey = []string{
"AWS_CSM_HOST",
}
csmPortEnvKey = []string{ csmPortEnvKey = []string{
"AWS_CSM_PORT", "AWS_CSM_PORT",
} }
@ -125,6 +209,10 @@ var (
"AWS_SESSION_TOKEN", "AWS_SESSION_TOKEN",
} }
enableEndpointDiscoveryEnvKey = []string{
"AWS_ENABLE_ENDPOINT_DISCOVERY",
}
regionEnvKeys = []string{ regionEnvKeys = []string{
"AWS_REGION", "AWS_REGION",
"AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set "AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set
@ -139,6 +227,45 @@ var (
sharedConfigFileEnvKey = []string{ sharedConfigFileEnvKey = []string{
"AWS_CONFIG_FILE", "AWS_CONFIG_FILE",
} }
webIdentityTokenFilePathEnvKey = []string{
"AWS_WEB_IDENTITY_TOKEN_FILE",
}
roleARNEnvKey = []string{
"AWS_ROLE_ARN",
}
roleSessionNameEnvKey = []string{
"AWS_ROLE_SESSION_NAME",
}
stsRegionalEndpointKey = []string{
"AWS_STS_REGIONAL_ENDPOINTS",
}
s3UsEast1RegionalEndpoint = []string{
"AWS_S3_US_EAST_1_REGIONAL_ENDPOINT",
}
s3UseARNRegionEnvKey = []string{
"AWS_S3_USE_ARN_REGION",
}
ec2IMDSEndpointEnvKey = []string{
"AWS_EC2_METADATA_SERVICE_ENDPOINT",
}
ec2IMDSEndpointModeEnvKey = []string{
"AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE",
}
useCABundleKey = []string{
"AWS_CA_BUNDLE",
}
useClientTLSCert = []string{
"AWS_SDK_GO_CLIENT_TLS_CERT",
}
useClientTLSKey = []string{
"AWS_SDK_GO_CLIENT_TLS_KEY",
}
awsUseDualStackEndpoint = []string{
"AWS_USE_DUALSTACK_ENDPOINT",
}
awsUseFIPSEndpoint = []string{
"AWS_USE_FIPS_ENDPOINT",
}
) )
// loadEnvConfig retrieves the SDK's environment configuration. // loadEnvConfig retrieves the SDK's environment configuration.
@ -147,7 +274,7 @@ var (
// If the environment variable `AWS_SDK_LOAD_CONFIG` is set to a truthy value // If the environment variable `AWS_SDK_LOAD_CONFIG` is set to a truthy value
// the shared SDK config will be loaded in addition to the SDK's specific // the shared SDK config will be loaded in addition to the SDK's specific
// configuration values. // configuration values.
func loadEnvConfig() envConfig { func loadEnvConfig() (envConfig, error) {
enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG")) enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG"))
return envConfigLoad(enableSharedConfig) return envConfigLoad(enableSharedConfig)
} }
@ -158,30 +285,42 @@ func loadEnvConfig() envConfig {
// Loads the shared configuration in addition to the SDK's specific configuration. // Loads the shared configuration in addition to the SDK's specific configuration.
// This will load the same values as `loadEnvConfig` if the `AWS_SDK_LOAD_CONFIG` // This will load the same values as `loadEnvConfig` if the `AWS_SDK_LOAD_CONFIG`
// environment variable is set. // environment variable is set.
func loadSharedEnvConfig() envConfig { func loadSharedEnvConfig() (envConfig, error) {
return envConfigLoad(true) return envConfigLoad(true)
} }
func envConfigLoad(enableSharedConfig bool) envConfig { func envConfigLoad(enableSharedConfig bool) (envConfig, error) {
cfg := envConfig{} cfg := envConfig{}
cfg.EnableSharedConfig = enableSharedConfig cfg.EnableSharedConfig = enableSharedConfig
setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey) // Static environment credentials
setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey) var creds credentials.Value
setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey) setFromEnvVal(&creds.AccessKeyID, credAccessEnvKey)
setFromEnvVal(&creds.SecretAccessKey, credSecretEnvKey)
setFromEnvVal(&creds.SessionToken, credSessionEnvKey)
if creds.HasKeys() {
// Require logical grouping of credentials
creds.ProviderName = EnvProviderName
cfg.Creds = creds
}
// Role Metadata
setFromEnvVal(&cfg.RoleARN, roleARNEnvKey)
setFromEnvVal(&cfg.RoleSessionName, roleSessionNameEnvKey)
// Web identity environment variables
setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)
// CSM environment variables // CSM environment variables
setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey) setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
setFromEnvVal(&cfg.CSMHost, csmHostEnvKey)
setFromEnvVal(&cfg.CSMPort, csmPortEnvKey) setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey) setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
cfg.CSMEnabled = len(cfg.csmEnabled) > 0
// Require logical grouping of credentials if len(cfg.csmEnabled) != 0 {
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 { v, _ := strconv.ParseBool(cfg.csmEnabled)
cfg.Creds = credentials.Value{} cfg.CSMEnabled = &v
} else {
cfg.Creds.ProviderName = EnvProviderName
} }
regionKeys := regionEnvKeys regionKeys := regionEnvKeys
@ -194,6 +333,12 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
setFromEnvVal(&cfg.Region, regionKeys) setFromEnvVal(&cfg.Region, regionKeys)
setFromEnvVal(&cfg.Profile, profileKeys) setFromEnvVal(&cfg.Profile, profileKeys)
// endpoint discovery is in reference to it being enabled.
setFromEnvVal(&cfg.enableEndpointDiscovery, enableEndpointDiscoveryEnvKey)
if len(cfg.enableEndpointDiscovery) > 0 {
cfg.EnableEndpointDiscovery = aws.Bool(cfg.enableEndpointDiscovery != "false")
}
setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey) setFromEnvVal(&cfg.SharedCredentialsFile, sharedCredsFileEnvKey)
setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey) setFromEnvVal(&cfg.SharedConfigFile, sharedConfigFileEnvKey)
@ -204,16 +349,123 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
cfg.SharedConfigFile = defaults.SharedConfigFilename() cfg.SharedConfigFile = defaults.SharedConfigFilename()
} }
cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE") setFromEnvVal(&cfg.CustomCABundle, useCABundleKey)
setFromEnvVal(&cfg.ClientTLSCert, useClientTLSCert)
setFromEnvVal(&cfg.ClientTLSKey, useClientTLSKey)
return cfg var err error
// STS Regional Endpoint variable
for _, k := range stsRegionalEndpointKey {
if v := os.Getenv(k); len(v) != 0 {
cfg.STSRegionalEndpoint, err = endpoints.GetSTSRegionalEndpoint(v)
if err != nil {
return cfg, fmt.Errorf("failed to load, %v from env config, %v", k, err)
}
}
}
// S3 Regional Endpoint variable
for _, k := range s3UsEast1RegionalEndpoint {
if v := os.Getenv(k); len(v) != 0 {
cfg.S3UsEast1RegionalEndpoint, err = endpoints.GetS3UsEast1RegionalEndpoint(v)
if err != nil {
return cfg, fmt.Errorf("failed to load, %v from env config, %v", k, err)
}
}
}
var s3UseARNRegion string
setFromEnvVal(&s3UseARNRegion, s3UseARNRegionEnvKey)
if len(s3UseARNRegion) != 0 {
switch {
case strings.EqualFold(s3UseARNRegion, "false"):
cfg.S3UseARNRegion = false
case strings.EqualFold(s3UseARNRegion, "true"):
cfg.S3UseARNRegion = true
default:
return envConfig{}, fmt.Errorf(
"invalid value for environment variable, %s=%s, need true or false",
s3UseARNRegionEnvKey[0], s3UseARNRegion)
}
}
setFromEnvVal(&cfg.EC2IMDSEndpoint, ec2IMDSEndpointEnvKey)
if err := setEC2IMDSEndpointMode(&cfg.EC2IMDSEndpointMode, ec2IMDSEndpointModeEnvKey); err != nil {
return envConfig{}, err
}
if err := setUseDualStackEndpointFromEnvVal(&cfg.UseDualStackEndpoint, awsUseDualStackEndpoint); err != nil {
return cfg, err
}
if err := setUseFIPSEndpointFromEnvVal(&cfg.UseFIPSEndpoint, awsUseFIPSEndpoint); err != nil {
return cfg, err
}
return cfg, nil
} }
func setFromEnvVal(dst *string, keys []string) { func setFromEnvVal(dst *string, keys []string) {
for _, k := range keys { for _, k := range keys {
if v := os.Getenv(k); len(v) > 0 { if v := os.Getenv(k); len(v) != 0 {
*dst = v *dst = v
break break
} }
} }
} }
func setEC2IMDSEndpointMode(mode *endpoints.EC2IMDSEndpointModeState, keys []string) error {
for _, k := range keys {
value := os.Getenv(k)
if len(value) == 0 {
continue
}
if err := mode.SetFromString(value); err != nil {
return fmt.Errorf("invalid value for environment variable, %s=%s, %v", k, value, err)
}
return nil
}
return nil
}
func setUseDualStackEndpointFromEnvVal(dst *endpoints.DualStackEndpointState, keys []string) error {
for _, k := range keys {
value := os.Getenv(k)
if len(value) == 0 {
continue // skip if empty
}
switch {
case strings.EqualFold(value, "true"):
*dst = endpoints.DualStackEndpointStateEnabled
case strings.EqualFold(value, "false"):
*dst = endpoints.DualStackEndpointStateDisabled
default:
return fmt.Errorf(
"invalid value for environment variable, %s=%s, need true, false",
k, value)
}
}
return nil
}
func setUseFIPSEndpointFromEnvVal(dst *endpoints.FIPSEndpointState, keys []string) error {
for _, k := range keys {
value := os.Getenv(k)
if len(value) == 0 {
continue // skip if empty
}
switch {
case strings.EqualFold(value, "true"):
*dst = endpoints.FIPSEndpointStateEnabled
case strings.EqualFold(value, "false"):
*dst = endpoints.FIPSEndpointStateDisabled
default:
return fmt.Errorf(
"invalid value for environment variable, %s=%s, need true, false",
k, value)
}
}
return nil
}

View file

@ -8,19 +8,44 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/corehandlers" "github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/csm" "github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
) )
const (
// ErrCodeSharedConfig represents an error that occurs in the shared
// configuration logic
ErrCodeSharedConfig = "SharedConfigErr"
// ErrCodeLoadCustomCABundle error code for unable to load custom CA bundle.
ErrCodeLoadCustomCABundle = "LoadCustomCABundleError"
// ErrCodeLoadClientTLSCert error code for unable to load client TLS
// certificate or key
ErrCodeLoadClientTLSCert = "LoadClientTLSCertError"
)
// ErrSharedConfigSourceCollision will be returned if a section contains both
// source_profile and credential_source
var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only one credential type may be specified per profile: source profile, credential source, credential process, web identity token, or sso", nil)
// ErrSharedConfigECSContainerEnvVarEmpty will be returned if the environment
// variables are empty and Environment was set as the credential source
var ErrSharedConfigECSContainerEnvVarEmpty = awserr.New(ErrCodeSharedConfig, "EcsContainer was specified as the credential_source, but 'AWS_CONTAINER_CREDENTIALS_RELATIVE_URI' was not set", nil)
// ErrSharedConfigInvalidCredSource will be returned if an invalid credential source was provided
var ErrSharedConfigInvalidCredSource = awserr.New(ErrCodeSharedConfig, "credential source values must be EcsContainer, Ec2InstanceMetadata, or Environment", nil)
// A Session provides a central location to create service clients from and // A Session provides a central location to create service clients from and
// store configurations and request handlers for those services. // store configurations and request handlers for those services.
// //
@ -31,6 +56,8 @@ import (
type Session struct { type Session struct {
Config *aws.Config Config *aws.Config
Handlers request.Handlers Handlers request.Handlers
options Options
} }
// New creates a new instance of the handlers merging in the provided configs // New creates a new instance of the handlers merging in the provided configs
@ -56,7 +83,7 @@ type Session struct {
// func is called instead of waiting to receive an error until a request is made. // func is called instead of waiting to receive an error until a request is made.
func New(cfgs ...*aws.Config) *Session { func New(cfgs ...*aws.Config) *Session {
// load initial config from environment // load initial config from environment
envCfg := loadEnvConfig() envCfg, envErr := loadEnvConfig()
if envCfg.EnableSharedConfig { if envCfg.EnableSharedConfig {
var cfg aws.Config var cfg aws.Config
@ -76,19 +103,28 @@ func New(cfgs ...*aws.Config) *Session {
// Session creation failed, need to report the error and prevent // Session creation failed, need to report the error and prevent
// any requests from succeeding. // any requests from succeeding.
s = &Session{Config: defaults.Config()} s = &Session{Config: defaults.Config()}
s.Config.MergeIn(cfgs...) s.logDeprecatedNewSessionError(msg, err, cfgs)
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
} }
return s return s
} }
s := deprecatedNewSession(cfgs...) s := deprecatedNewSession(envCfg, cfgs...)
if envCfg.CSMEnabled { if envErr != nil {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger) msg := "failed to load env config"
s.logDeprecatedNewSessionError(msg, envErr, cfgs)
}
if csmCfg, err := loadCSMConfig(envCfg, []string{}); err != nil {
if l := s.Config.Logger; l != nil {
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
}
} else if csmCfg.Enabled {
err := enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
if err != nil {
msg := "failed to enable CSM"
s.logDeprecatedNewSessionError(msg, err, cfgs)
}
} }
return s return s
@ -107,7 +143,7 @@ func New(cfgs ...*aws.Config) *Session {
// to be built with retrieving credentials with AssumeRole set in the config. // to be built with retrieving credentials with AssumeRole set in the config.
// //
// See the NewSessionWithOptions func for information on how to override or // See the NewSessionWithOptions func for information on how to override or
// control through code how the Session will be created. Such as specifying the // control through code how the Session will be created, such as specifying the
// config profile, and controlling if shared config is enabled or not. // config profile, and controlling if shared config is enabled or not.
func NewSession(cfgs ...*aws.Config) (*Session, error) { func NewSession(cfgs ...*aws.Config) (*Session, error) {
opts := Options{} opts := Options{}
@ -191,20 +227,88 @@ type Options struct {
// the config enables assume role wit MFA via the mfa_serial field. // the config enables assume role wit MFA via the mfa_serial field.
AssumeRoleTokenProvider func() (string, error) AssumeRoleTokenProvider func() (string, error)
// When the SDK's shared config is configured to assume a role this option
// may be provided to set the expiry duration of the STS credentials.
// Defaults to 15 minutes if not set as documented in the
// stscreds.AssumeRoleProvider.
AssumeRoleDuration time.Duration
// Reader for a custom Credentials Authority (CA) bundle in PEM format that // Reader for a custom Credentials Authority (CA) bundle in PEM format that
// the SDK will use instead of the default system's root CA bundle. Use this // the SDK will use instead of the default system's root CA bundle. Use this
// only if you want to replace the CA bundle the SDK uses for TLS requests. // only if you want to replace the CA bundle the SDK uses for TLS requests.
// //
// Enabling this option will attempt to merge the Transport into the SDK's HTTP // HTTP Client's Transport concrete implementation must be a http.Transport
// client. If the client's Transport is not a http.Transport an error will be // or creating the session will fail.
// returned. If the Transport's TLS config is set this option will cause the SDK //
// If the Transport's TLS config is set this option will cause the SDK
// to overwrite the Transport's TLS config's RootCAs value. If the CA // to overwrite the Transport's TLS config's RootCAs value. If the CA
// bundle reader contains multiple certificates all of them will be loaded. // bundle reader contains multiple certificates all of them will be loaded.
// //
// The Session option CustomCABundle is also available when creating sessions // Can also be specified via the environment variable:
// to also enable this feature. CustomCABundle session option field has priority //
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set. // AWS_CA_BUNDLE=$HOME/ca_bundle
//
// Can also be specified via the shared config field:
//
// ca_bundle = $HOME/ca_bundle
CustomCABundle io.Reader CustomCABundle io.Reader
// Reader for the TLC client certificate that should be used by the SDK's
// HTTP transport when making requests. The certificate must be paired with
// a TLS client key file. Will be ignored if both are not provided.
//
// HTTP Client's Transport concrete implementation must be a http.Transport
// or creating the session will fail.
//
// Can also be specified via the environment variable:
//
// AWS_SDK_GO_CLIENT_TLS_CERT=$HOME/my_client_cert
ClientTLSCert io.Reader
// Reader for the TLC client key that should be used by the SDK's HTTP
// transport when making requests. The key must be paired with a TLS client
// certificate file. Will be ignored if both are not provided.
//
// HTTP Client's Transport concrete implementation must be a http.Transport
// or creating the session will fail.
//
// Can also be specified via the environment variable:
//
// AWS_SDK_GO_CLIENT_TLS_KEY=$HOME/my_client_key
ClientTLSKey io.Reader
// The handlers that the session and all API clients will be created with.
// This must be a complete set of handlers. Use the defaults.Handlers()
// function to initialize this value before changing the handlers to be
// used by the SDK.
Handlers request.Handlers
// Allows specifying a custom endpoint to be used by the EC2 IMDS client
// when making requests to the EC2 IMDS API. The endpoint value should
// include the URI scheme. If the scheme is not present it will be defaulted to http.
//
// If unset, will the EC2 IMDS client will use its default endpoint.
//
// Can also be specified via the environment variable,
// AWS_EC2_METADATA_SERVICE_ENDPOINT.
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT=http://169.254.169.254
//
// If using an URL with an IPv6 address literal, the IPv6 address
// component must be enclosed in square brackets.
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1]
EC2IMDSEndpoint string
// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6
EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState
// Specifies options for creating credential providers.
// These are only used if the aws.Config does not already
// include credentials.
CredentialsProviderOptions *CredentialsProviderOptions
} }
// NewSessionWithOptions returns a new Session created from SDK defaults, config files, // NewSessionWithOptions returns a new Session created from SDK defaults, config files,
@ -238,13 +342,20 @@ type Options struct {
// })) // }))
func NewSessionWithOptions(opts Options) (*Session, error) { func NewSessionWithOptions(opts Options) (*Session, error) {
var envCfg envConfig var envCfg envConfig
var err error
if opts.SharedConfigState == SharedConfigEnable { if opts.SharedConfigState == SharedConfigEnable {
envCfg = loadSharedEnvConfig() envCfg, err = loadSharedEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to load shared config, %v", err)
}
} else { } else {
envCfg = loadEnvConfig() envCfg, err = loadEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to load environment config, %v", err)
}
} }
if len(opts.Profile) > 0 { if len(opts.Profile) != 0 {
envCfg.Profile = opts.Profile envCfg.Profile = opts.Profile
} }
@ -255,17 +366,6 @@ func NewSessionWithOptions(opts Options) (*Session, error) {
envCfg.EnableSharedConfig = true envCfg.EnableSharedConfig = true
} }
// Only use AWS_CA_BUNDLE if session option is not provided.
if len(envCfg.CustomCABundle) != 0 && opts.CustomCABundle == nil {
f, err := os.Open(envCfg.CustomCABundle)
if err != nil {
return nil, awserr.New("LoadCustomCABundleError",
"failed to open custom CA bundle PEM file", err)
}
defer f.Close()
opts.CustomCABundle = f
}
return newSession(opts, envCfg, &opts.Config) return newSession(opts, envCfg, &opts.Config)
} }
@ -284,7 +384,29 @@ func Must(sess *Session, err error) *Session {
return sess return sess
} }
func deprecatedNewSession(cfgs ...*aws.Config) *Session { // Wraps the endpoint resolver with a resolver that will return a custom
// endpoint for EC2 IMDS.
func wrapEC2IMDSEndpoint(resolver endpoints.Resolver, endpoint string, mode endpoints.EC2IMDSEndpointModeState) endpoints.Resolver {
return endpoints.ResolverFunc(
func(service, region string, opts ...func(*endpoints.Options)) (
endpoints.ResolvedEndpoint, error,
) {
if service == ec2MetadataServiceID && len(endpoint) > 0 {
return endpoints.ResolvedEndpoint{
URL: endpoint,
SigningName: ec2MetadataServiceID,
SigningRegion: region,
}, nil
} else if service == ec2MetadataServiceID {
opts = append(opts, func(o *endpoints.Options) {
o.EC2MetadataEndpointMode = mode
})
}
return resolver.EndpointFor(service, region, opts...)
})
}
func deprecatedNewSession(envCfg envConfig, cfgs ...*aws.Config) *Session {
cfg := defaults.Config() cfg := defaults.Config()
handlers := defaults.Handlers() handlers := defaults.Handlers()
@ -296,6 +418,11 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session {
// endpoints for service client configurations. // endpoints for service client configurations.
cfg.EndpointResolver = endpoints.DefaultResolver() cfg.EndpointResolver = endpoints.DefaultResolver()
} }
if !(len(envCfg.EC2IMDSEndpoint) == 0 && envCfg.EC2IMDSEndpointMode == endpoints.EC2IMDSEndpointModeStateUnset) {
cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, envCfg.EC2IMDSEndpoint, envCfg.EC2IMDSEndpointMode)
}
cfg.Credentials = defaults.CredChain(cfg, handlers) cfg.Credentials = defaults.CredChain(cfg, handlers)
// Reapply any passed in configs to override credentials if set // Reapply any passed in configs to override credentials if set
@ -304,33 +431,42 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session {
s := &Session{ s := &Session{
Config: cfg, Config: cfg,
Handlers: handlers, Handlers: handlers,
options: Options{
EC2IMDSEndpoint: envCfg.EC2IMDSEndpoint,
},
} }
initHandlers(s) initHandlers(s)
return s return s
} }
func enableCSM(handlers *request.Handlers, clientID string, port string, logger aws.Logger) { func enableCSM(handlers *request.Handlers, cfg csmConfig, logger aws.Logger) error {
if logger != nil {
logger.Log("Enabling CSM") logger.Log("Enabling CSM")
if len(port) == 0 {
port = csm.DefaultPort
} }
r, err := csm.Start(clientID, "127.0.0.1:"+port) r, err := csm.Start(cfg.ClientID, csm.AddressWithDefaults(cfg.Host, cfg.Port))
if err != nil { if err != nil {
return return err
} }
r.InjectHandlers(handlers) r.InjectHandlers(handlers)
return nil
} }
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) { func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
cfg := defaults.Config() cfg := defaults.Config()
handlers := defaults.Handlers()
handlers := opts.Handlers
if handlers.IsEmpty() {
handlers = defaults.Handlers()
}
// Get a merged version of the user provided config to determine if // Get a merged version of the user provided config to determine if
// credentials were. // credentials were.
userCfg := &aws.Config{} userCfg := &aws.Config{}
userCfg.MergeIn(cfgs...) userCfg.MergeIn(cfgs...)
cfg.MergeIn(userCfg)
// Ordered config files will be loaded in with later files overwriting // Ordered config files will be loaded in with later files overwriting
// previous config file values. // previous config file values.
@ -347,28 +483,42 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
} }
// Load additional config from file(s) // Load additional config from file(s)
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles) sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles, envCfg.EnableSharedConfig)
if err != nil { if err != nil {
if len(envCfg.Profile) == 0 && !envCfg.EnableSharedConfig && (envCfg.Creds.HasKeys() || userCfg.Credentials != nil) {
// Special case where the user has not explicitly specified an AWS_PROFILE,
// or session.Options.profile, shared config is not enabled, and the
// environment has credentials, allow the shared config file to fail to
// load since the user has already provided credentials, and nothing else
// is required to be read file. Github(aws/aws-sdk-go#2455)
} else if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return nil, err return nil, err
} }
}
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil { if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
return nil, err return nil, err
} }
if err := setTLSOptions(&opts, cfg, envCfg, sharedCfg); err != nil {
return nil, err
}
s := &Session{ s := &Session{
Config: cfg, Config: cfg,
Handlers: handlers, Handlers: handlers,
options: opts,
} }
initHandlers(s) initHandlers(s)
if envCfg.CSMEnabled {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
}
// Setup HTTP client with custom cert bundle if enabled if csmCfg, err := loadCSMConfig(envCfg, cfgFiles); err != nil {
if opts.CustomCABundle != nil { if l := s.Config.Logger; l != nil {
if err := loadCustomCABundle(s, opts.CustomCABundle); err != nil { l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
}
} else if csmCfg.Enabled {
err = enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
if err != nil {
return nil, err return nil, err
} }
} }
@ -376,19 +526,123 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
return s, nil return s, nil
} }
func loadCustomCABundle(s *Session, bundle io.Reader) error { type csmConfig struct {
Enabled bool
Host string
Port string
ClientID string
}
var csmProfileName = "aws_csm"
func loadCSMConfig(envCfg envConfig, cfgFiles []string) (csmConfig, error) {
if envCfg.CSMEnabled != nil {
if *envCfg.CSMEnabled {
return csmConfig{
Enabled: true,
ClientID: envCfg.CSMClientID,
Host: envCfg.CSMHost,
Port: envCfg.CSMPort,
}, nil
}
return csmConfig{}, nil
}
sharedCfg, err := loadSharedConfig(csmProfileName, cfgFiles, false)
if err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return csmConfig{}, err
}
}
if sharedCfg.CSMEnabled != nil && *sharedCfg.CSMEnabled == true {
return csmConfig{
Enabled: true,
ClientID: sharedCfg.CSMClientID,
Host: sharedCfg.CSMHost,
Port: sharedCfg.CSMPort,
}, nil
}
return csmConfig{}, nil
}
func setTLSOptions(opts *Options, cfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig) error {
// CA Bundle can be specified in both environment variable shared config file.
var caBundleFilename = envCfg.CustomCABundle
if len(caBundleFilename) == 0 {
caBundleFilename = sharedCfg.CustomCABundle
}
// Only use environment value if session option is not provided.
customTLSOptions := map[string]struct {
filename string
field *io.Reader
errCode string
}{
"custom CA bundle PEM": {filename: caBundleFilename, field: &opts.CustomCABundle, errCode: ErrCodeLoadCustomCABundle},
"custom client TLS cert": {filename: envCfg.ClientTLSCert, field: &opts.ClientTLSCert, errCode: ErrCodeLoadClientTLSCert},
"custom client TLS key": {filename: envCfg.ClientTLSKey, field: &opts.ClientTLSKey, errCode: ErrCodeLoadClientTLSCert},
}
for name, v := range customTLSOptions {
if len(v.filename) != 0 && *v.field == nil {
f, err := os.Open(v.filename)
if err != nil {
return awserr.New(v.errCode, fmt.Sprintf("failed to open %s file", name), err)
}
defer f.Close()
*v.field = f
}
}
// Setup HTTP client with custom cert bundle if enabled
if opts.CustomCABundle != nil {
if err := loadCustomCABundle(cfg.HTTPClient, opts.CustomCABundle); err != nil {
return err
}
}
// Setup HTTP client TLS certificate and key for client TLS authentication.
if opts.ClientTLSCert != nil && opts.ClientTLSKey != nil {
if err := loadClientTLSCert(cfg.HTTPClient, opts.ClientTLSCert, opts.ClientTLSKey); err != nil {
return err
}
} else if opts.ClientTLSCert == nil && opts.ClientTLSKey == nil {
// Do nothing if neither values are available.
} else {
return awserr.New(ErrCodeLoadClientTLSCert,
fmt.Sprintf("client TLS cert(%t) and key(%t) must both be provided",
opts.ClientTLSCert != nil, opts.ClientTLSKey != nil), nil)
}
return nil
}
func getHTTPTransport(client *http.Client) (*http.Transport, error) {
var t *http.Transport var t *http.Transport
switch v := s.Config.HTTPClient.Transport.(type) { switch v := client.Transport.(type) {
case *http.Transport: case *http.Transport:
t = v t = v
default: default:
if s.Config.HTTPClient.Transport != nil { if client.Transport != nil {
return awserr.New("LoadCustomCABundleError", return nil, fmt.Errorf("unsupported transport, %T", client.Transport)
"unable to load custom CA bundle, HTTPClient's transport unsupported type", nil)
} }
} }
if t == nil { if t == nil {
t = &http.Transport{} // Nil transport implies `http.DefaultTransport` should be used. Since
// the SDK cannot modify, nor copy the `DefaultTransport` specifying
// the values the next closest behavior.
t = getCustomTransport()
}
return t, nil
}
func loadCustomCABundle(client *http.Client, bundle io.Reader) error {
t, err := getHTTPTransport(client)
if err != nil {
return awserr.New(ErrCodeLoadCustomCABundle,
"unable to load custom CA bundle, HTTPClient's transport unsupported type", err)
} }
p, err := loadCertPool(bundle) p, err := loadCertPool(bundle)
@ -400,7 +654,7 @@ func loadCustomCABundle(s *Session, bundle io.Reader) error {
} }
t.TLSClientConfig.RootCAs = p t.TLSClientConfig.RootCAs = p
s.Config.HTTPClient.Transport = t client.Transport = t
return nil return nil
} }
@ -408,22 +662,62 @@ func loadCustomCABundle(s *Session, bundle io.Reader) error {
func loadCertPool(r io.Reader) (*x509.CertPool, error) { func loadCertPool(r io.Reader) (*x509.CertPool, error) {
b, err := ioutil.ReadAll(r) b, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
return nil, awserr.New("LoadCustomCABundleError", return nil, awserr.New(ErrCodeLoadCustomCABundle,
"failed to read custom CA bundle PEM file", err) "failed to read custom CA bundle PEM file", err)
} }
p := x509.NewCertPool() p := x509.NewCertPool()
if !p.AppendCertsFromPEM(b) { if !p.AppendCertsFromPEM(b) {
return nil, awserr.New("LoadCustomCABundleError", return nil, awserr.New(ErrCodeLoadCustomCABundle,
"failed to load custom CA bundle PEM file", err) "failed to load custom CA bundle PEM file", err)
} }
return p, nil return p, nil
} }
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error { func loadClientTLSCert(client *http.Client, certFile, keyFile io.Reader) error {
// Merge in user provided configuration t, err := getHTTPTransport(client)
cfg.MergeIn(userCfg) if err != nil {
return awserr.New(ErrCodeLoadClientTLSCert,
"unable to get usable HTTP transport from client", err)
}
cert, err := ioutil.ReadAll(certFile)
if err != nil {
return awserr.New(ErrCodeLoadClientTLSCert,
"unable to get read client TLS cert file", err)
}
key, err := ioutil.ReadAll(keyFile)
if err != nil {
return awserr.New(ErrCodeLoadClientTLSCert,
"unable to get read client TLS key file", err)
}
clientCert, err := tls.X509KeyPair(cert, key)
if err != nil {
return awserr.New(ErrCodeLoadClientTLSCert,
"unable to load x509 key pair from client cert", err)
}
tlsCfg := t.TLSClientConfig
if tlsCfg == nil {
tlsCfg = &tls.Config{}
}
tlsCfg.Certificates = append(tlsCfg.Certificates, clientCert)
t.TLSClientConfig = tlsCfg
client.Transport = t
return nil
}
func mergeConfigSrcs(cfg, userCfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) error {
// Region if not already set by user // Region if not already set by user
if len(aws.StringValue(cfg.Region)) == 0 { if len(aws.StringValue(cfg.Region)) == 0 {
@ -434,101 +728,109 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
} }
} }
// Configure credentials if not already set if cfg.EnableEndpointDiscovery == nil {
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil { if envCfg.EnableEndpointDiscovery != nil {
if len(envCfg.Creds.AccessKeyID) > 0 { cfg.WithEndpointDiscovery(*envCfg.EnableEndpointDiscovery)
cfg.Credentials = credentials.NewStaticCredentialsFromCreds( } else if envCfg.EnableSharedConfig && sharedCfg.EnableEndpointDiscovery != nil {
envCfg.Creds, cfg.WithEndpointDiscovery(*sharedCfg.EnableEndpointDiscovery)
)
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
cfgCp := *cfg
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.AssumeRoleSource.Creds,
)
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return AssumeRoleTokenProviderNotSetError{}
} }
cfg.Credentials = stscreds.NewCredentials(
&Session{
Config: &cfgCp,
Handlers: handlers.Copy(),
},
sharedCfg.AssumeRole.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName
// Assume role with external ID
if len(sharedCfg.AssumeRole.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID)
} }
// Assume role with MFA // Regional Endpoint flag for STS endpoint resolving
if len(sharedCfg.AssumeRole.MFASerial) > 0 { mergeSTSRegionalEndpointConfig(cfg, []endpoints.STSRegionalEndpoint{
opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial) userCfg.STSRegionalEndpoint,
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider envCfg.STSRegionalEndpoint,
} sharedCfg.STSRegionalEndpoint,
}, endpoints.LegacySTSEndpoint,
)
} else if len(sharedCfg.Creds.AccessKeyID) > 0 {
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
} else {
// Fallback to default credentials provider, include mock errors
// for the credential chain so user can identify why credentials
// failed to be retrieved.
cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)},
&credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)},
defaults.RemoteCredProvider(*cfg, handlers),
},
}) })
// Regional Endpoint flag for S3 endpoint resolving
mergeS3UsEast1RegionalEndpointConfig(cfg, []endpoints.S3UsEast1RegionalEndpoint{
userCfg.S3UsEast1RegionalEndpoint,
envCfg.S3UsEast1RegionalEndpoint,
sharedCfg.S3UsEast1RegionalEndpoint,
endpoints.LegacyS3UsEast1Endpoint,
})
var ec2IMDSEndpoint string
for _, v := range []string{
sessOpts.EC2IMDSEndpoint,
envCfg.EC2IMDSEndpoint,
sharedCfg.EC2IMDSEndpoint,
} {
if len(v) != 0 {
ec2IMDSEndpoint = v
break
}
}
var endpointMode endpoints.EC2IMDSEndpointModeState
for _, v := range []endpoints.EC2IMDSEndpointModeState{
sessOpts.EC2IMDSEndpointMode,
envCfg.EC2IMDSEndpointMode,
sharedCfg.EC2IMDSEndpointMode,
} {
if v != endpoints.EC2IMDSEndpointModeStateUnset {
endpointMode = v
break
}
}
if len(ec2IMDSEndpoint) != 0 || endpointMode != endpoints.EC2IMDSEndpointModeStateUnset {
cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, ec2IMDSEndpoint, endpointMode)
}
// Configure credentials if not already set by the user when creating the
// Session.
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
creds, err := resolveCredentials(cfg, envCfg, sharedCfg, handlers, sessOpts)
if err != nil {
return err
}
cfg.Credentials = creds
}
cfg.S3UseARNRegion = userCfg.S3UseARNRegion
if cfg.S3UseARNRegion == nil {
cfg.S3UseARNRegion = &envCfg.S3UseARNRegion
}
if cfg.S3UseARNRegion == nil {
cfg.S3UseARNRegion = &sharedCfg.S3UseARNRegion
}
for _, v := range []endpoints.DualStackEndpointState{userCfg.UseDualStackEndpoint, envCfg.UseDualStackEndpoint, sharedCfg.UseDualStackEndpoint} {
if v != endpoints.DualStackEndpointStateUnset {
cfg.UseDualStackEndpoint = v
break
}
}
for _, v := range []endpoints.FIPSEndpointState{userCfg.UseFIPSEndpoint, envCfg.UseFIPSEndpoint, sharedCfg.UseFIPSEndpoint} {
if v != endpoints.FIPSEndpointStateUnset {
cfg.UseFIPSEndpoint = v
break
} }
} }
return nil return nil
} }
// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the func mergeSTSRegionalEndpointConfig(cfg *aws.Config, values []endpoints.STSRegionalEndpoint) {
// MFAToken option is not set when shared config is configured load assume a for _, v := range values {
// role with an MFA token. if v != endpoints.UnsetSTSEndpoint {
type AssumeRoleTokenProviderNotSetError struct{} cfg.STSRegionalEndpoint = v
break
// Code is the short id of the error. }
func (e AssumeRoleTokenProviderNotSetError) Code() string { }
return "AssumeRoleTokenProviderNotSetError"
} }
// Message is the description of the error func mergeS3UsEast1RegionalEndpointConfig(cfg *aws.Config, values []endpoints.S3UsEast1RegionalEndpoint) {
func (e AssumeRoleTokenProviderNotSetError) Message() string { for _, v := range values {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.") if v != endpoints.UnsetS3UsEast1Endpoint {
} cfg.S3UsEast1RegionalEndpoint = v
break
// OrigErr is the underlying error that caused the failure. }
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error { }
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
var emptyCreds = credentials.Value{}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
} }
func initHandlers(s *Session) { func initHandlers(s *Session) {
@ -539,7 +841,7 @@ func initHandlers(s *Session) {
} }
} }
// Copy creates and returns a copy of the current Session, coping the config // Copy creates and returns a copy of the current Session, copying the config
// and handlers. If any additional configs are provided they will be merged // and handlers. If any additional configs are provided they will be merged
// on top of the Session's copied config. // on top of the Session's copied config.
// //
@ -549,6 +851,7 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session {
newSession := &Session{ newSession := &Session{
Config: s.Config.Copy(cfgs...), Config: s.Config.Copy(cfgs...),
Handlers: s.Handlers.Copy(), Handlers: s.Handlers.Copy(),
options: s.options,
} }
initHandlers(newSession) initHandlers(newSession)
@ -559,47 +862,82 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session {
// ClientConfig satisfies the client.ConfigProvider interface and is used to // ClientConfig satisfies the client.ConfigProvider interface and is used to
// configure the service client instances. Passing the Session to the service // configure the service client instances. Passing the Session to the service
// client's constructor (New) will use this method to configure the client. // client's constructor (New) will use this method to configure the client.
func (s *Session) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config { func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Config {
// Backwards compatibility, the error will be eaten if user calls ClientConfig
// directly. All SDK services will use ClientconfigWithError.
cfg, _ := s.clientConfigWithErr(serviceName, cfgs...)
return cfg
}
func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (client.Config, error) {
s = s.Copy(cfgs...) s = s.Copy(cfgs...)
var resolved endpoints.ResolvedEndpoint resolvedRegion := normalizeRegion(s.Config)
var err error
region := aws.StringValue(s.Config.Region) region := aws.StringValue(s.Config.Region)
resolved, err := s.resolveEndpoint(service, region, resolvedRegion, s.Config)
if endpoint := aws.StringValue(s.Config.Endpoint); len(endpoint) != 0 { if err != nil {
resolved.URL = endpoints.AddScheme(endpoint, aws.BoolValue(s.Config.DisableSSL)) s.Handlers.Validate.PushBack(func(r *request.Request) {
resolved.SigningRegion = region if len(r.ClientInfo.Endpoint) != 0 {
} else { // Error occurred while resolving endpoint, but the request
resolved, err = s.Config.EndpointResolver.EndpointFor( // being invoked has had an endpoint specified after the client
serviceName, region, // was created.
func(opt *endpoints.Options) { return
opt.DisableSSL = aws.BoolValue(s.Config.DisableSSL) }
opt.UseDualStack = aws.BoolValue(s.Config.UseDualStack) r.Error = err
})
// Support the condition where the service is modeled but its
// endpoint metadata is not available.
opt.ResolveUnknownService = true
},
)
} }
return client.Config{ return client.Config{
Config: s.Config, Config: s.Config,
Handlers: s.Handlers, Handlers: s.Handlers,
PartitionID: resolved.PartitionID,
Endpoint: resolved.URL, Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion, SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived, SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName, SigningName: resolved.SigningName,
}, err ResolvedRegion: resolvedRegion,
}
}
const ec2MetadataServiceID = "ec2metadata"
func (s *Session) resolveEndpoint(service, region, resolvedRegion string, cfg *aws.Config) (endpoints.ResolvedEndpoint, error) {
if ep := aws.StringValue(cfg.Endpoint); len(ep) != 0 {
return endpoints.ResolvedEndpoint{
URL: endpoints.AddScheme(ep, aws.BoolValue(cfg.DisableSSL)),
SigningRegion: region,
}, nil
}
resolved, err := cfg.EndpointResolver.EndpointFor(service, region,
func(opt *endpoints.Options) {
opt.DisableSSL = aws.BoolValue(cfg.DisableSSL)
opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)
opt.UseDualStackEndpoint = cfg.UseDualStackEndpoint
opt.UseFIPSEndpoint = cfg.UseFIPSEndpoint
// Support for STSRegionalEndpoint where the STSRegionalEndpoint is
// provided in envConfig or sharedConfig with envConfig getting
// precedence.
opt.STSRegionalEndpoint = cfg.STSRegionalEndpoint
// Support for S3UsEast1RegionalEndpoint where the S3UsEast1RegionalEndpoint is
// provided in envConfig or sharedConfig with envConfig getting
// precedence.
opt.S3UsEast1RegionalEndpoint = cfg.S3UsEast1RegionalEndpoint
// Support the condition where the service is modeled but its
// endpoint metadata is not available.
opt.ResolveUnknownService = true
opt.ResolvedRegion = resolvedRegion
opt.Logger = cfg.Logger
opt.LogDeprecated = cfg.LogLevel.Matches(aws.LogDebugWithDeprecated)
},
)
if err != nil {
return endpoints.ResolvedEndpoint{}, err
}
return resolved, nil
} }
// ClientConfigNoResolveEndpoint is the same as ClientConfig with the exception // ClientConfigNoResolveEndpoint is the same as ClientConfig with the exception
@ -608,13 +946,12 @@ func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (
func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Config { func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Config {
s = s.Copy(cfgs...) s = s.Copy(cfgs...)
resolvedRegion := normalizeRegion(s.Config)
var resolved endpoints.ResolvedEndpoint var resolved endpoints.ResolvedEndpoint
region := aws.StringValue(s.Config.Region)
if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 { if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 {
resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL)) resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL))
resolved.SigningRegion = region resolved.SigningRegion = aws.StringValue(s.Config.Region)
} }
return client.Config{ return client.Config{
@ -624,5 +961,37 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
SigningRegion: resolved.SigningRegion, SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived, SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName, SigningName: resolved.SigningName,
ResolvedRegion: resolvedRegion,
} }
} }
// logDeprecatedNewSessionError function enables error handling for session
func (s *Session) logDeprecatedNewSessionError(msg string, err error, cfgs []*aws.Config) {
// Session creation failed, need to report the error and prevent
// any requests from succeeding.
s.Config.MergeIn(cfgs...)
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
}
// normalizeRegion resolves / normalizes the configured region (converts pseudo fips regions), and modifies the provided
// config to have the equivalent options for resolution and returns the resolved region name.
func normalizeRegion(cfg *aws.Config) (resolved string) {
const fipsInfix = "-fips-"
const fipsPrefix = "-fips"
const fipsSuffix = "fips-"
region := aws.StringValue(cfg.Region)
if strings.Contains(region, fipsInfix) ||
strings.Contains(region, fipsPrefix) ||
strings.Contains(region, fipsSuffix) {
resolved = strings.Replace(strings.Replace(strings.Replace(
region, fipsInfix, "-", -1), fipsPrefix, "", -1), fipsSuffix, "", -1)
cfg.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled
}
return resolved
}

View file

@ -2,11 +2,13 @@ package session
import ( import (
"fmt" "fmt"
"io/ioutil" "strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/go-ini/ini" "github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/internal/ini"
) )
const ( const (
@ -17,67 +19,185 @@ const (
// Assume Role Credentials group // Assume Role Credentials group
roleArnKey = `role_arn` // group required roleArnKey = `role_arn` // group required
sourceProfileKey = `source_profile` // group required sourceProfileKey = `source_profile` // group required (or credential_source)
credentialSourceKey = `credential_source` // group required (or source_profile)
externalIDKey = `external_id` // optional externalIDKey = `external_id` // optional
mfaSerialKey = `mfa_serial` // optional mfaSerialKey = `mfa_serial` // optional
roleSessionNameKey = `role_session_name` // optional roleSessionNameKey = `role_session_name` // optional
roleDurationSecondsKey = "duration_seconds" // optional
// AWS Single Sign-On (AWS SSO) group
ssoAccountIDKey = "sso_account_id"
ssoRegionKey = "sso_region"
ssoRoleNameKey = "sso_role_name"
ssoStartURL = "sso_start_url"
// CSM options
csmEnabledKey = `csm_enabled`
csmHostKey = `csm_host`
csmPortKey = `csm_port`
csmClientIDKey = `csm_client_id`
// Additional Config fields // Additional Config fields
regionKey = `region` regionKey = `region`
// custom CA Bundle filename
customCABundleKey = `ca_bundle`
// endpoint discovery group
enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional
// External Credential Process
credentialProcessKey = `credential_process` // optional
// Web Identity Token File
webIdentityTokenFileKey = `web_identity_token_file` // optional
// Additional config fields for regional or legacy endpoints
stsRegionalEndpointSharedKey = `sts_regional_endpoints`
// Additional config fields for regional or legacy endpoints
s3UsEast1RegionalSharedKey = `s3_us_east_1_regional_endpoint`
// DefaultSharedConfigProfile is the default profile to be used when // DefaultSharedConfigProfile is the default profile to be used when
// loading configuration from the config files if another profile name // loading configuration from the config files if another profile name
// is not provided. // is not provided.
DefaultSharedConfigProfile = `default` DefaultSharedConfigProfile = `default`
)
type assumeRoleConfig struct { // S3 ARN Region Usage
RoleARN string s3UseARNRegionKey = "s3_use_arn_region"
SourceProfile string
ExternalID string // EC2 IMDS Endpoint Mode
MFASerial string ec2MetadataServiceEndpointModeKey = "ec2_metadata_service_endpoint_mode"
RoleSessionName string
} // EC2 IMDS Endpoint
ec2MetadataServiceEndpointKey = "ec2_metadata_service_endpoint"
// Use DualStack Endpoint Resolution
useDualStackEndpoint = "use_dualstack_endpoint"
// Use FIPS Endpoint Resolution
useFIPSEndpointKey = "use_fips_endpoint"
)
// sharedConfig represents the configuration fields of the SDK config files. // sharedConfig represents the configuration fields of the SDK config files.
type sharedConfig struct { type sharedConfig struct {
// Credentials values from the config file. Both aws_access_key_id Profile string
// and aws_secret_access_key must be provided together in the same file
// to be considered valid. The values will be ignored if not a complete group. // Credentials values from the config file. Both aws_access_key_id and
// aws_session_token is an optional field that can be provided if both of the // aws_secret_access_key must be provided together in the same file to be
// other two fields are also provided. // considered valid. The values will be ignored if not a complete group.
// aws_session_token is an optional field that can be provided if both of
// the other two fields are also provided.
// //
// aws_access_key_id // aws_access_key_id
// aws_secret_access_key // aws_secret_access_key
// aws_session_token // aws_session_token
Creds credentials.Value Creds credentials.Value
AssumeRole assumeRoleConfig CredentialSource string
AssumeRoleSource *sharedConfig CredentialProcess string
WebIdentityTokenFile string
// Region is the region the SDK should use for looking up AWS service endpoints SSOAccountID string
// and signing requests. SSORegion string
SSORoleName string
SSOStartURL string
RoleARN string
RoleSessionName string
ExternalID string
MFASerial string
AssumeRoleDuration *time.Duration
SourceProfileName string
SourceProfile *sharedConfig
// Region is the region the SDK should use for looking up AWS service
// endpoints and signing requests.
// //
// region // region
Region string Region string
// CustomCABundle is the file path to a PEM file the SDK will read and
// use to configure the HTTP transport with additional CA certs that are
// not present in the platforms default CA store.
//
// This value will be ignored if the file does not exist.
//
// ca_bundle
CustomCABundle string
// EnableEndpointDiscovery can be enabled in the shared config by setting
// endpoint_discovery_enabled to true
//
// endpoint_discovery_enabled = true
EnableEndpointDiscovery *bool
// CSM Options
CSMEnabled *bool
CSMHost string
CSMPort string
CSMClientID string
// Specifies the Regional Endpoint flag for the SDK to resolve the endpoint for a service
//
// sts_regional_endpoints = regional
// This can take value as `LegacySTSEndpoint` or `RegionalSTSEndpoint`
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// Specifies the Regional Endpoint flag for the SDK to resolve the endpoint for a service
//
// s3_us_east_1_regional_endpoint = regional
// This can take value as `LegacyS3UsEast1Endpoint` or `RegionalS3UsEast1Endpoint`
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
// Specifies if the S3 service should allow ARNs to direct the region
// the client's requests are sent to.
//
// s3_use_arn_region=true
S3UseARNRegion bool
// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
//
// ec2_metadata_service_endpoint_mode=IPv6
EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState
// Specifies the EC2 Instance Metadata Service endpoint to use. If specified it overrides EC2IMDSEndpointMode.
//
// ec2_metadata_service_endpoint=http://fd00:ec2::254
EC2IMDSEndpoint string
// Specifies that SDK clients must resolve a dual-stack endpoint for
// services.
//
// use_dualstack_endpoint=true
UseDualStackEndpoint endpoints.DualStackEndpointState
// Specifies that SDK clients must resolve a FIPS endpoint for
// services.
//
// use_fips_endpoint=true
UseFIPSEndpoint endpoints.FIPSEndpointState
} }
type sharedConfigFile struct { type sharedConfigFile struct {
Filename string Filename string
IniData *ini.File IniData ini.Sections
} }
// loadSharedConfig retrieves the configuration from the list of files // loadSharedConfig retrieves the configuration from the list of files using
// using the profile provided. The order the files are listed will determine // the profile provided. The order the files are listed will determine
// precedence. Values in subsequent files will overwrite values defined in // precedence. Values in subsequent files will overwrite values defined in
// earlier files. // earlier files.
// //
// For example, given two files A and B. Both define credentials. If the order // For example, given two files A and B. Both define credentials. If the order
// of the files are A then B, B's credential values will be used instead of A's. // of the files are A then B, B's credential values will be used instead of
// A's.
// //
// See sharedConfig.setFromFile for information how the config files // See sharedConfig.setFromFile for information how the config files
// will be loaded. // will be loaded.
func loadSharedConfig(profile string, filenames []string) (sharedConfig, error) { func loadSharedConfig(profile string, filenames []string, exOpts bool) (sharedConfig, error) {
if len(profile) == 0 { if len(profile) == 0 {
profile = DefaultSharedConfigProfile profile = DefaultSharedConfigProfile
} }
@ -88,16 +208,11 @@ func loadSharedConfig(profile string, filenames []string) (sharedConfig, error)
} }
cfg := sharedConfig{} cfg := sharedConfig{}
if err = cfg.setFromIniFiles(profile, files); err != nil { profiles := map[string]struct{}{}
if err = cfg.setFromIniFiles(profiles, profile, files, exOpts); err != nil {
return sharedConfig{}, err return sharedConfig{}, err
} }
if len(cfg.AssumeRole.SourceProfile) > 0 {
if err := cfg.setAssumeRoleSource(profile, files); err != nil {
return sharedConfig{}, err
}
}
return cfg, nil return cfg, nil
} }
@ -105,114 +220,365 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
files := make([]sharedConfigFile, 0, len(filenames)) files := make([]sharedConfigFile, 0, len(filenames))
for _, filename := range filenames { for _, filename := range filenames {
b, err := ioutil.ReadFile(filename) sections, err := ini.OpenFile(filename)
if err != nil { if aerr, ok := err.(awserr.Error); ok && aerr.Code() == ini.ErrCodeUnableToReadFile {
// Skip files which can't be opened and read for whatever reason // Skip files which can't be opened and read for whatever reason
continue continue
} } else if err != nil {
f, err := ini.Load(b)
if err != nil {
return nil, SharedConfigLoadError{Filename: filename, Err: err} return nil, SharedConfigLoadError{Filename: filename, Err: err}
} }
files = append(files, sharedConfigFile{ files = append(files, sharedConfigFile{
Filename: filename, IniData: f, Filename: filename, IniData: sections,
}) })
} }
return files, nil return files, nil
} }
func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedConfigFile) error { func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error {
var assumeRoleSrc sharedConfig cfg.Profile = profile
// Multiple level assume role chains are not support
if cfg.AssumeRole.SourceProfile == origProfile {
assumeRoleSrc = *cfg
assumeRoleSrc.AssumeRole = assumeRoleConfig{}
} else {
err := assumeRoleSrc.setFromIniFiles(cfg.AssumeRole.SourceProfile, files)
if err != nil {
return err
}
}
if len(assumeRoleSrc.Creds.AccessKeyID) == 0 {
return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN}
}
cfg.AssumeRoleSource = &assumeRoleSrc
return nil
}
func (cfg *sharedConfig) setFromIniFiles(profile string, files []sharedConfigFile) error {
// Trim files from the list that don't exist. // Trim files from the list that don't exist.
var skippedFiles int
var profileNotFoundErr error
for _, f := range files { for _, f := range files {
if err := cfg.setFromIniFile(profile, f); err != nil { if err := cfg.setFromIniFile(profile, f, exOpts); err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); ok { if _, ok := err.(SharedConfigProfileNotExistsError); ok {
// Ignore proviles missings // Ignore profiles not defined in individual files.
profileNotFoundErr = err
skippedFiles++
continue continue
} }
return err return err
} }
} }
if skippedFiles == len(files) {
// If all files were skipped because the profile is not found, return
// the original profile not found error.
return profileNotFoundErr
}
if _, ok := profiles[profile]; ok {
// if this is the second instance of the profile the Assume Role
// options must be cleared because they are only valid for the
// first reference of a profile. The self linked instance of the
// profile only have credential provider options.
cfg.clearAssumeRoleOptions()
} else {
// First time a profile has been seen, It must either be a assume role
// credentials, or SSO. Assert if the credential type requires a role ARN,
// the ARN is also set, or validate that the SSO configuration is complete.
if err := cfg.validateCredentialsConfig(profile); err != nil {
return err
}
}
profiles[profile] = struct{}{}
if err := cfg.validateCredentialType(); err != nil {
return err
}
// Link source profiles for assume roles
if len(cfg.SourceProfileName) != 0 {
// Linked profile via source_profile ignore credential provider
// options, the source profile must provide the credentials.
cfg.clearCredentialOptions()
srcCfg := &sharedConfig{}
err := srcCfg.setFromIniFiles(profiles, cfg.SourceProfileName, files, exOpts)
if err != nil {
// SourceProfile that doesn't exist is an error in configuration.
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
err = SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
return err
}
if !srcCfg.hasCredentials() {
return SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
cfg.SourceProfile = srcCfg
}
return nil return nil
} }
// setFromFile loads the configuration from the file using // setFromFile loads the configuration from the file using the profile
// the profile provided. A sharedConfig pointer type value is used so that // provided. A sharedConfig pointer type value is used so that multiple config
// multiple config file loadings can be chained. // file loadings can be chained.
// //
// Only loads complete logically grouped values, and will not set fields in cfg // Only loads complete logically grouped values, and will not set fields in cfg
// for incomplete grouped values in the config. Such as credentials. For example // for incomplete grouped values in the config. Such as credentials. For
// if a config file only includes aws_access_key_id but no aws_secret_access_key // example if a config file only includes aws_access_key_id but no
// the aws_access_key_id will be ignored. // aws_secret_access_key the aws_access_key_id will be ignored.
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) error { func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, exOpts bool) error {
section, err := file.IniData.GetSection(profile) section, ok := file.IniData.GetSection(profile)
if err != nil { if !ok {
// Fallback to to alternate profile name: profile <name> // Fallback to to alternate profile name: profile <name>
section, err = file.IniData.GetSection(fmt.Sprintf("profile %s", profile)) section, ok = file.IniData.GetSection(fmt.Sprintf("profile %s", profile))
if !ok {
return SharedConfigProfileNotExistsError{Profile: profile, Err: nil}
}
}
if exOpts {
// Assume Role Parameters
updateString(&cfg.RoleARN, section, roleArnKey)
updateString(&cfg.ExternalID, section, externalIDKey)
updateString(&cfg.MFASerial, section, mfaSerialKey)
updateString(&cfg.RoleSessionName, section, roleSessionNameKey)
updateString(&cfg.SourceProfileName, section, sourceProfileKey)
updateString(&cfg.CredentialSource, section, credentialSourceKey)
updateString(&cfg.Region, section, regionKey)
updateString(&cfg.CustomCABundle, section, customCABundleKey)
if section.Has(roleDurationSecondsKey) {
d := time.Duration(section.Int(roleDurationSecondsKey)) * time.Second
cfg.AssumeRoleDuration = &d
}
if v := section.String(stsRegionalEndpointSharedKey); len(v) != 0 {
sre, err := endpoints.GetSTSRegionalEndpoint(v)
if err != nil { if err != nil {
return SharedConfigProfileNotExistsError{Profile: profile, Err: err} return fmt.Errorf("failed to load %s from shared config, %s, %v",
stsRegionalEndpointSharedKey, file.Filename, err)
} }
cfg.STSRegionalEndpoint = sre
} }
if v := section.String(s3UsEast1RegionalSharedKey); len(v) != 0 {
sre, err := endpoints.GetS3UsEast1RegionalEndpoint(v)
if err != nil {
return fmt.Errorf("failed to load %s from shared config, %s, %v",
s3UsEast1RegionalSharedKey, file.Filename, err)
}
cfg.S3UsEast1RegionalEndpoint = sre
}
// AWS Single Sign-On (AWS SSO)
updateString(&cfg.SSOAccountID, section, ssoAccountIDKey)
updateString(&cfg.SSORegion, section, ssoRegionKey)
updateString(&cfg.SSORoleName, section, ssoRoleNameKey)
updateString(&cfg.SSOStartURL, section, ssoStartURL)
if err := updateEC2MetadataServiceEndpointMode(&cfg.EC2IMDSEndpointMode, section, ec2MetadataServiceEndpointModeKey); err != nil {
return fmt.Errorf("failed to load %s from shared config, %s, %v",
ec2MetadataServiceEndpointModeKey, file.Filename, err)
}
updateString(&cfg.EC2IMDSEndpoint, section, ec2MetadataServiceEndpointKey)
updateUseDualStackEndpoint(&cfg.UseDualStackEndpoint, section, useDualStackEndpoint)
updateUseFIPSEndpoint(&cfg.UseFIPSEndpoint, section, useFIPSEndpointKey)
}
updateString(&cfg.CredentialProcess, section, credentialProcessKey)
updateString(&cfg.WebIdentityTokenFile, section, webIdentityTokenFileKey)
// Shared Credentials // Shared Credentials
akid := section.Key(accessKeyIDKey).String() creds := credentials.Value{
secret := section.Key(secretAccessKey).String() AccessKeyID: section.String(accessKeyIDKey),
if len(akid) > 0 && len(secret) > 0 { SecretAccessKey: section.String(secretAccessKey),
cfg.Creds = credentials.Value{ SessionToken: section.String(sessionTokenKey),
AccessKeyID: akid,
SecretAccessKey: secret,
SessionToken: section.Key(sessionTokenKey).String(),
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename), ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
} }
if creds.HasKeys() {
cfg.Creds = creds
} }
// Assume Role // Endpoint discovery
roleArn := section.Key(roleArnKey).String() updateBoolPtr(&cfg.EnableEndpointDiscovery, section, enableEndpointDiscoveryKey)
srcProfile := section.Key(sourceProfileKey).String()
if len(roleArn) > 0 && len(srcProfile) > 0 {
cfg.AssumeRole = assumeRoleConfig{
RoleARN: roleArn,
SourceProfile: srcProfile,
ExternalID: section.Key(externalIDKey).String(),
MFASerial: section.Key(mfaSerialKey).String(),
RoleSessionName: section.Key(roleSessionNameKey).String(),
}
}
// Region // CSM options
if v := section.Key(regionKey).String(); len(v) > 0 { updateBoolPtr(&cfg.CSMEnabled, section, csmEnabledKey)
cfg.Region = v updateString(&cfg.CSMHost, section, csmHostKey)
updateString(&cfg.CSMPort, section, csmPortKey)
updateString(&cfg.CSMClientID, section, csmClientIDKey)
updateBool(&cfg.S3UseARNRegion, section, s3UseARNRegionKey)
return nil
}
func updateEC2MetadataServiceEndpointMode(endpointMode *endpoints.EC2IMDSEndpointModeState, section ini.Section, key string) error {
if !section.Has(key) {
return nil
}
value := section.String(key)
return endpointMode.SetFromString(value)
}
func (cfg *sharedConfig) validateCredentialsConfig(profile string) error {
if err := cfg.validateCredentialsRequireARN(profile); err != nil {
return err
} }
return nil return nil
} }
func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error {
var credSource string
switch {
case len(cfg.SourceProfileName) != 0:
credSource = sourceProfileKey
case len(cfg.CredentialSource) != 0:
credSource = credentialSourceKey
case len(cfg.WebIdentityTokenFile) != 0:
credSource = webIdentityTokenFileKey
}
if len(credSource) != 0 && len(cfg.RoleARN) == 0 {
return CredentialRequiresARNError{
Type: credSource,
Profile: profile,
}
}
return nil
}
func (cfg *sharedConfig) validateCredentialType() error {
// Only one or no credential type can be defined.
if !oneOrNone(
len(cfg.SourceProfileName) != 0,
len(cfg.CredentialSource) != 0,
len(cfg.CredentialProcess) != 0,
len(cfg.WebIdentityTokenFile) != 0,
) {
return ErrSharedConfigSourceCollision
}
return nil
}
func (cfg *sharedConfig) validateSSOConfiguration() error {
if !cfg.hasSSOConfiguration() {
return nil
}
var missing []string
if len(cfg.SSOAccountID) == 0 {
missing = append(missing, ssoAccountIDKey)
}
if len(cfg.SSORegion) == 0 {
missing = append(missing, ssoRegionKey)
}
if len(cfg.SSORoleName) == 0 {
missing = append(missing, ssoRoleNameKey)
}
if len(cfg.SSOStartURL) == 0 {
missing = append(missing, ssoStartURL)
}
if len(missing) > 0 {
return fmt.Errorf("profile %q is configured to use SSO but is missing required configuration: %s",
cfg.Profile, strings.Join(missing, ", "))
}
return nil
}
func (cfg *sharedConfig) hasCredentials() bool {
switch {
case len(cfg.SourceProfileName) != 0:
case len(cfg.CredentialSource) != 0:
case len(cfg.CredentialProcess) != 0:
case len(cfg.WebIdentityTokenFile) != 0:
case cfg.hasSSOConfiguration():
case cfg.Creds.HasKeys():
default:
return false
}
return true
}
func (cfg *sharedConfig) clearCredentialOptions() {
cfg.CredentialSource = ""
cfg.CredentialProcess = ""
cfg.WebIdentityTokenFile = ""
cfg.Creds = credentials.Value{}
cfg.SSOAccountID = ""
cfg.SSORegion = ""
cfg.SSORoleName = ""
cfg.SSOStartURL = ""
}
func (cfg *sharedConfig) clearAssumeRoleOptions() {
cfg.RoleARN = ""
cfg.ExternalID = ""
cfg.MFASerial = ""
cfg.RoleSessionName = ""
cfg.SourceProfileName = ""
}
func (cfg *sharedConfig) hasSSOConfiguration() bool {
switch {
case len(cfg.SSOAccountID) != 0:
case len(cfg.SSORegion) != 0:
case len(cfg.SSORoleName) != 0:
case len(cfg.SSOStartURL) != 0:
default:
return false
}
return true
}
func oneOrNone(bs ...bool) bool {
var count int
for _, b := range bs {
if b {
count++
if count > 1 {
return false
}
}
}
return true
}
// updateString will only update the dst with the value in the section key, key
// is present in the section.
func updateString(dst *string, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = section.String(key)
}
// updateBool will only update the dst with the value in the section key, key
// is present in the section.
func updateBool(dst *bool, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = section.Bool(key)
}
// updateBoolPtr will only update the dst with the value in the section key,
// key is present in the section.
func updateBoolPtr(dst **bool, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = new(bool)
**dst = section.Bool(key)
}
// SharedConfigLoadError is an error for the shared config file failed to load. // SharedConfigLoadError is an error for the shared config file failed to load.
type SharedConfigLoadError struct { type SharedConfigLoadError struct {
Filename string Filename string
@ -271,6 +637,7 @@ func (e SharedConfigProfileNotExistsError) Error() string {
// or not complete. // or not complete.
type SharedConfigAssumeRoleError struct { type SharedConfigAssumeRoleError struct {
RoleARN string RoleARN string
SourceProfile string
} }
// Code is the short id of the error. // Code is the short id of the error.
@ -280,8 +647,10 @@ func (e SharedConfigAssumeRoleError) Code() string {
// Message is the description of the error // Message is the description of the error
func (e SharedConfigAssumeRoleError) Message() string { func (e SharedConfigAssumeRoleError) Message() string {
return fmt.Sprintf("failed to load assume role for %s, source profile has no shared credentials", return fmt.Sprintf(
e.RoleARN) "failed to load assume role for %s, source profile %s has no shared credentials",
e.RoleARN, e.SourceProfile,
)
} }
// OrigErr is the underlying error that caused the failure. // OrigErr is the underlying error that caused the failure.
@ -293,3 +662,68 @@ func (e SharedConfigAssumeRoleError) OrigErr() error {
func (e SharedConfigAssumeRoleError) Error() string { func (e SharedConfigAssumeRoleError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil) return awserr.SprintError(e.Code(), e.Message(), "", nil)
} }
// CredentialRequiresARNError provides the error for shared config credentials
// that are incorrectly configured in the shared config or credentials file.
type CredentialRequiresARNError struct {
// type of credentials that were configured.
Type string
// Profile name the credentials were in.
Profile string
}
// Code is the short id of the error.
func (e CredentialRequiresARNError) Code() string {
return "CredentialRequiresARNError"
}
// Message is the description of the error
func (e CredentialRequiresARNError) Message() string {
return fmt.Sprintf(
"credential type %s requires role_arn, profile %s",
e.Type, e.Profile,
)
}
// OrigErr is the underlying error that caused the failure.
func (e CredentialRequiresARNError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e CredentialRequiresARNError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
// updateEndpointDiscoveryType will only update the dst with the value in the section, if
// a valid key and corresponding EndpointDiscoveryType is found.
func updateUseDualStackEndpoint(dst *endpoints.DualStackEndpointState, section ini.Section, key string) {
if !section.Has(key) {
return
}
if section.Bool(key) {
*dst = endpoints.DualStackEndpointStateEnabled
} else {
*dst = endpoints.DualStackEndpointStateDisabled
}
return
}
// updateEndpointDiscoveryType will only update the dst with the value in the section, if
// a valid key and corresponding EndpointDiscoveryType is found.
func updateUseFIPSEndpoint(dst *endpoints.FIPSEndpointState, section ini.Section, key string) {
if !section.Has(key) {
return
}
if section.Bool(key) {
*dst = endpoints.FIPSEndpointStateEnabled
} else {
*dst = endpoints.FIPSEndpointStateDisabled
}
return
}

View file

@ -1,8 +1,7 @@
package v4 package v4
import ( import (
"net/http" "github.com/aws/aws-sdk-go/internal/strings"
"strings"
) )
// validator houses a set of rule needed for validation of a // validator houses a set of rule needed for validation of a
@ -35,23 +34,23 @@ func (m mapRule) IsValid(value string) bool {
return ok return ok
} }
// whitelist is a generic rule for whitelisting // allowList is a generic rule for allow listing
type whitelist struct { type allowList struct {
rule rule
} }
// IsValid for whitelist checks if the value is within the whitelist // IsValid for allow list checks if the value is within the allow list
func (w whitelist) IsValid(value string) bool { func (w allowList) IsValid(value string) bool {
return w.rule.IsValid(value) return w.rule.IsValid(value)
} }
// blacklist is a generic rule for blacklisting // excludeList is a generic rule for exclude listing
type blacklist struct { type excludeList struct {
rule rule
} }
// IsValid for whitelist checks if the value is within the whitelist // IsValid for exclude list checks if the value is within the exclude list
func (b blacklist) IsValid(value string) bool { func (b excludeList) IsValid(value string) bool {
return !b.rule.IsValid(value) return !b.rule.IsValid(value)
} }
@ -61,7 +60,7 @@ type patterns []string
// been found // been found
func (p patterns) IsValid(value string) bool { func (p patterns) IsValid(value string) bool {
for _, pattern := range p { for _, pattern := range p {
if strings.HasPrefix(http.CanonicalHeaderKey(value), pattern) { if strings.HasPrefixFold(value, pattern) {
return true return true
} }
} }

View file

@ -0,0 +1,14 @@
//go:build !go1.7
// +build !go1.7
package v4
import (
"net/http"
"github.com/aws/aws-sdk-go/aws"
)
func requestContext(r *http.Request) aws.Context {
return aws.BackgroundContext()
}

View file

@ -0,0 +1,14 @@
//go:build go1.7
// +build go1.7
package v4
import (
"net/http"
"github.com/aws/aws-sdk-go/aws"
)
func requestContext(r *http.Request) aws.Context {
return r.Context()
}

View file

@ -0,0 +1,63 @@
package v4
import (
"encoding/hex"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
type credentialValueProvider interface {
Get() (credentials.Value, error)
}
// StreamSigner implements signing of event stream encoded payloads
type StreamSigner struct {
region string
service string
credentials credentialValueProvider
prevSig []byte
}
// NewStreamSigner creates a SigV4 signer used to sign Event Stream encoded messages
func NewStreamSigner(region, service string, seedSignature []byte, credentials *credentials.Credentials) *StreamSigner {
return &StreamSigner{
region: region,
service: service,
credentials: credentials,
prevSig: seedSignature,
}
}
// GetSignature takes an event stream encoded headers and payload and returns a signature
func (s *StreamSigner) GetSignature(headers, payload []byte, date time.Time) ([]byte, error) {
credValue, err := s.credentials.Get()
if err != nil {
return nil, err
}
sigKey := deriveSigningKey(s.region, s.service, credValue.SecretAccessKey, date)
keyPath := buildSigningScope(s.region, s.service, date)
stringToSign := buildEventStreamStringToSign(headers, payload, s.prevSig, keyPath, date)
signature := hmacSHA256(sigKey, []byte(stringToSign))
s.prevSig = signature
return signature, nil
}
func buildEventStreamStringToSign(headers, payload, prevSig []byte, scope string, date time.Time) string {
return strings.Join([]string{
"AWS4-HMAC-SHA256-PAYLOAD",
formatTime(date),
scope,
hex.EncodeToString(prevSig),
hex.EncodeToString(hashSHA256(headers)),
hex.EncodeToString(hashSHA256(payload)),
}, "\n")
}

View file

@ -1,3 +1,4 @@
//go:build go1.5
// +build go1.5 // +build go1.5
package v4 package v4

View file

@ -76,27 +76,32 @@ import (
) )
const ( const (
authorizationHeader = "Authorization"
authHeaderSignatureElem = "Signature="
signatureQueryKey = "X-Amz-Signature"
authHeaderPrefix = "AWS4-HMAC-SHA256" authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z" timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102" shortTimeFormat = "20060102"
awsV4Request = "aws4_request"
// emptyStringSHA256 is a SHA256 of an empty string // emptyStringSHA256 is a SHA256 of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855` emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
) )
var ignoredHeaders = rules{ var ignoredHeaders = rules{
blacklist{ excludeList{
mapRule{ mapRule{
"Authorization": struct{}{}, authorizationHeader: struct{}{},
"User-Agent": struct{}{}, "User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{}, "X-Amzn-Trace-Id": struct{}{},
}, },
}, },
} }
// requiredSignedHeaders is a whitelist for build canonical headers. // requiredSignedHeaders is a allow list for build canonical headers.
var requiredSignedHeaders = rules{ var requiredSignedHeaders = rules{
whitelist{ allowList{
mapRule{ mapRule{
"Cache-Control": struct{}{}, "Cache-Control": struct{}{},
"Content-Disposition": struct{}{}, "Content-Disposition": struct{}{},
@ -134,17 +139,19 @@ var requiredSignedHeaders = rules{
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{}, "X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{}, "X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Storage-Class": struct{}{}, "X-Amz-Storage-Class": struct{}{},
"X-Amz-Tagging": struct{}{},
"X-Amz-Website-Redirect-Location": struct{}{}, "X-Amz-Website-Redirect-Location": struct{}{},
"X-Amz-Content-Sha256": struct{}{}, "X-Amz-Content-Sha256": struct{}{},
}, },
}, },
patterns{"X-Amz-Meta-"}, patterns{"X-Amz-Meta-"},
patterns{"X-Amz-Object-Lock-"},
} }
// allowedHoisting is a whitelist for build query headers. The boolean value // allowedHoisting is a allow list for build query headers. The boolean value
// represents whether or not it is a pattern. // represents whether or not it is a pattern.
var allowedQueryHoisting = inclusiveRules{ var allowedQueryHoisting = inclusiveRules{
blacklist{requiredSignedHeaders}, excludeList{requiredSignedHeaders},
patterns{"X-Amz-"}, patterns{"X-Amz-"},
} }
@ -181,7 +188,7 @@ type Signer struct {
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html // http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool DisableURIPathEscaping bool
// Disales the automatical setting of the HTTP request's Body field with the // Disables the automatical setting of the HTTP request's Body field with the
// io.ReadSeeker passed in to the signer. This is useful if you're using a // io.ReadSeeker passed in to the signer. This is useful if you're using a
// custom wrapper around the body for the io.ReadSeeker and want to preserve // custom wrapper around the body for the io.ReadSeeker and want to preserve
// the Body value on the Request.Body. // the Body value on the Request.Body.
@ -230,8 +237,6 @@ type signingCtx struct {
credValues credentials.Value credValues credentials.Value
isPresign bool isPresign bool
formattedTime string
formattedShortTime string
unsignedPayload bool unsignedPayload bool
bodyDigest string bodyDigest string
@ -336,7 +341,7 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
} }
var err error var err error
ctx.credValues, err = v4.Credentials.Get() ctx.credValues, err = v4.Credentials.GetWithContext(requestContext(r))
if err != nil { if err != nil {
return http.Header{}, err return http.Header{}, err
} }
@ -413,7 +418,7 @@ var SignRequestHandler = request.NamedHandler{
// request handler should only be used with the SDK's built in service client's // request handler should only be used with the SDK's built in service client's
// API operation requests. // API operation requests.
// //
// This function should not be used on its on its own, but in conjunction with // This function should not be used on its own, but in conjunction with
// an AWS service client's API operation call. To sign a standalone request // an AWS service client's API operation call. To sign a standalone request
// not created by a service client's API operation method use the "Sign" or // not created by a service client's API operation method use the "Sign" or
// "Presign" functions of the "Signer" type. // "Presign" functions of the "Signer" type.
@ -421,7 +426,7 @@ var SignRequestHandler = request.NamedHandler{
// If the credentials of the request's config are set to // If the credentials of the request's config are set to
// credentials.AnonymousCredentials the request will not be signed. // credentials.AnonymousCredentials the request will not be signed.
func SignSDKRequest(req *request.Request) { func SignSDKRequest(req *request.Request) {
signSDKRequestWithCurrTime(req, time.Now) SignSDKRequestWithCurrentTime(req, time.Now)
} }
// BuildNamedHandler will build a generic handler for signing. // BuildNamedHandler will build a generic handler for signing.
@ -429,12 +434,15 @@ func BuildNamedHandler(name string, opts ...func(*Signer)) request.NamedHandler
return request.NamedHandler{ return request.NamedHandler{
Name: name, Name: name,
Fn: func(req *request.Request) { Fn: func(req *request.Request) {
signSDKRequestWithCurrTime(req, time.Now, opts...) SignSDKRequestWithCurrentTime(req, time.Now, opts...)
}, },
} }
} }
func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) { // SignSDKRequestWithCurrentTime will sign the SDK's request using the time
// function passed in. Behaves the same as SignSDKRequest with the exception
// the request is signed with the value returned by the current time function.
func SignSDKRequestWithCurrentTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
// If the request does not need to be signed ignore the signing of the // If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used. // request if the AnonymousCredentials object is used.
if req.Config.Credentials == credentials.AnonymousCredentials { if req.Config.Credentials == credentials.AnonymousCredentials {
@ -470,13 +478,9 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
opt(v4) opt(v4)
} }
signingTime := req.Time curTime := curTimeFn()
if !req.LastSignedAt.IsZero() {
signingTime = req.LastSignedAt
}
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(), signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
name, region, req.ExpireTime, req.ExpireTime > 0, signingTime, name, region, req.ExpireTime, req.ExpireTime > 0, curTime,
) )
if err != nil { if err != nil {
req.Error = err req.Error = err
@ -485,7 +489,7 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
} }
req.SignedHeaderVals = signedHeaders req.SignedHeaderVals = signedHeaders
req.LastSignedAt = curTimeFn() req.LastSignedAt = curTime
} }
const logSignInfoMsg = `DEBUG: Request Signature: const logSignInfoMsg = `DEBUG: Request Signature:
@ -532,39 +536,56 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) error {
ctx.buildSignature() // depends on string to sign ctx.buildSignature() // depends on string to sign
if ctx.isPresign { if ctx.isPresign {
ctx.Request.URL.RawQuery += "&X-Amz-Signature=" + ctx.signature ctx.Request.URL.RawQuery += "&" + signatureQueryKey + "=" + ctx.signature
} else { } else {
parts := []string{ parts := []string{
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString, authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
"SignedHeaders=" + ctx.signedHeaders, "SignedHeaders=" + ctx.signedHeaders,
"Signature=" + ctx.signature, authHeaderSignatureElem + ctx.signature,
} }
ctx.Request.Header.Set("Authorization", strings.Join(parts, ", ")) ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
} }
return nil return nil
} }
func (ctx *signingCtx) buildTime() { // GetSignedRequestSignature attempts to extract the signature of the request.
ctx.formattedTime = ctx.Time.UTC().Format(timeFormat) // Returning an error if the request is unsigned, or unable to extract the
ctx.formattedShortTime = ctx.Time.UTC().Format(shortTimeFormat) // signature.
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
ps := strings.Split(auth, ", ")
for _, p := range ps {
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
sig := p[len(authHeaderSignatureElem):]
if len(sig) == 0 {
return nil, fmt.Errorf("invalid request signature authorization header")
}
return hex.DecodeString(sig)
}
}
}
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
return hex.DecodeString(sig)
}
return nil, fmt.Errorf("request not signed")
}
func (ctx *signingCtx) buildTime() {
if ctx.isPresign { if ctx.isPresign {
duration := int64(ctx.ExpireTime / time.Second) duration := int64(ctx.ExpireTime / time.Second)
ctx.Query.Set("X-Amz-Date", ctx.formattedTime) ctx.Query.Set("X-Amz-Date", formatTime(ctx.Time))
ctx.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10)) ctx.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else { } else {
ctx.Request.Header.Set("X-Amz-Date", ctx.formattedTime) ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
} }
} }
func (ctx *signingCtx) buildCredentialString() { func (ctx *signingCtx) buildCredentialString() {
ctx.credentialString = strings.Join([]string{ ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
ctx.formattedShortTime,
ctx.Region,
ctx.ServiceName,
"aws4_request",
}, "/")
if ctx.isPresign { if ctx.isPresign {
ctx.Query.Set("X-Amz-Credential", ctx.credValues.AccessKeyID+"/"+ctx.credentialString) ctx.Query.Set("X-Amz-Credential", ctx.credValues.AccessKeyID+"/"+ctx.credentialString)
@ -588,8 +609,7 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
var headers []string var headers []string
headers = append(headers, "host") headers = append(headers, "host")
for k, v := range header { for k, v := range header {
canonicalKey := http.CanonicalHeaderKey(k) if !r.IsValid(k) {
if !r.IsValid(canonicalKey) {
continue // ignored header continue // ignored header
} }
if ctx.SignedHeaderVals == nil { if ctx.SignedHeaderVals == nil {
@ -614,21 +634,25 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
ctx.Query.Set("X-Amz-SignedHeaders", ctx.signedHeaders) ctx.Query.Set("X-Amz-SignedHeaders", ctx.signedHeaders)
} }
headerValues := make([]string, len(headers)) headerItems := make([]string, len(headers))
for i, k := range headers { for i, k := range headers {
if k == "host" { if k == "host" {
if ctx.Request.Host != "" { if ctx.Request.Host != "" {
headerValues[i] = "host:" + ctx.Request.Host headerItems[i] = "host:" + ctx.Request.Host
} else { } else {
headerValues[i] = "host:" + ctx.Request.URL.Host headerItems[i] = "host:" + ctx.Request.URL.Host
} }
} else { } else {
headerValues[i] = k + ":" + headerValues := make([]string, len(ctx.SignedHeaderVals[k]))
strings.Join(ctx.SignedHeaderVals[k], ",") for i, v := range ctx.SignedHeaderVals[k] {
headerValues[i] = strings.TrimSpace(v)
}
headerItems[i] = k + ":" +
strings.Join(headerValues, ",")
} }
} }
stripExcessSpaces(headerValues) stripExcessSpaces(headerItems)
ctx.canonicalHeaders = strings.Join(headerValues, "\n") ctx.canonicalHeaders = strings.Join(headerItems, "\n")
} }
func (ctx *signingCtx) buildCanonicalString() { func (ctx *signingCtx) buildCanonicalString() {
@ -653,19 +677,15 @@ func (ctx *signingCtx) buildCanonicalString() {
func (ctx *signingCtx) buildStringToSign() { func (ctx *signingCtx) buildStringToSign() {
ctx.stringToSign = strings.Join([]string{ ctx.stringToSign = strings.Join([]string{
authHeaderPrefix, authHeaderPrefix,
ctx.formattedTime, formatTime(ctx.Time),
ctx.credentialString, ctx.credentialString,
hex.EncodeToString(makeSha256([]byte(ctx.canonicalString))), hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
}, "\n") }, "\n")
} }
func (ctx *signingCtx) buildSignature() { func (ctx *signingCtx) buildSignature() {
secret := ctx.credValues.SecretAccessKey creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
date := makeHmac([]byte("AWS4"+secret), []byte(ctx.formattedShortTime)) signature := hmacSHA256(creds, []byte(ctx.stringToSign))
region := makeHmac(date, []byte(ctx.Region))
service := makeHmac(region, []byte(ctx.ServiceName))
credentials := makeHmac(service, []byte("aws4_request"))
signature := makeHmac(credentials, []byte(ctx.stringToSign))
ctx.signature = hex.EncodeToString(signature) ctx.signature = hex.EncodeToString(signature)
} }
@ -674,9 +694,12 @@ func (ctx *signingCtx) buildBodyDigest() error {
if hash == "" { if hash == "" {
includeSHA256Header := ctx.unsignedPayload || includeSHA256Header := ctx.unsignedPayload ||
ctx.ServiceName == "s3" || ctx.ServiceName == "s3" ||
ctx.ServiceName == "s3-object-lambda" ||
ctx.ServiceName == "glacier" ctx.ServiceName == "glacier"
s3Presign := ctx.isPresign && ctx.ServiceName == "s3" s3Presign := ctx.isPresign &&
(ctx.ServiceName == "s3" ||
ctx.ServiceName == "s3-object-lambda")
if ctx.unsignedPayload || s3Presign { if ctx.unsignedPayload || s3Presign {
hash = "UNSIGNED-PAYLOAD" hash = "UNSIGNED-PAYLOAD"
@ -687,7 +710,11 @@ func (ctx *signingCtx) buildBodyDigest() error {
if !aws.IsReaderSeekable(ctx.Body) { if !aws.IsReaderSeekable(ctx.Body) {
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body) return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
} }
hash = hex.EncodeToString(makeSha256Reader(ctx.Body)) hashBytes, err := makeSha256Reader(ctx.Body)
if err != nil {
return err
}
hash = hex.EncodeToString(hashBytes)
} }
if includeSHA256Header { if includeSHA256Header {
@ -722,31 +749,45 @@ func (ctx *signingCtx) removePresign() {
ctx.Query.Del("X-Amz-SignedHeaders") ctx.Query.Del("X-Amz-SignedHeaders")
} }
func makeHmac(key []byte, data []byte) []byte { func hmacSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key) hash := hmac.New(sha256.New, key)
hash.Write(data) hash.Write(data)
return hash.Sum(nil) return hash.Sum(nil)
} }
func makeSha256(data []byte) []byte { func hashSHA256(data []byte) []byte {
hash := sha256.New() hash := sha256.New()
hash.Write(data) hash.Write(data)
return hash.Sum(nil) return hash.Sum(nil)
} }
func makeSha256Reader(reader io.ReadSeeker) []byte { func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
hash := sha256.New() hash := sha256.New()
start, _ := reader.Seek(0, sdkio.SeekCurrent) start, err := reader.Seek(0, sdkio.SeekCurrent)
defer reader.Seek(start, sdkio.SeekStart) if err != nil {
return nil, err
}
defer func() {
// ensure error is return if unable to seek back to start of payload.
_, err = reader.Seek(start, sdkio.SeekStart)
}()
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
size, err := aws.SeekerLen(reader)
if err != nil {
io.Copy(hash, reader) io.Copy(hash, reader)
return hash.Sum(nil) } else {
io.CopyN(hash, reader, size)
}
return hash.Sum(nil), nil
} }
const doubleSpace = " " const doubleSpace = " "
// stripExcessSpaces will rewrite the passed in slice's string values to not // stripExcessSpaces will rewrite the passed in slice's string values to not
// contain muliple side-by-side spaces. // contain multiple side-by-side spaces.
func stripExcessSpaces(vals []string) { func stripExcessSpaces(vals []string) {
var j, k, l, m, spaces int var j, k, l, m, spaces int
for i, str := range vals { for i, str := range vals {
@ -786,3 +827,28 @@ func stripExcessSpaces(vals []string) {
vals[i] = string(buf[:m]) vals[i] = string(buf[:m])
} }
} }
func buildSigningScope(region, service string, dt time.Time) string {
return strings.Join([]string{
formatShortTime(dt),
region,
service,
awsV4Request,
}, "/")
}
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
kDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(service))
signingKey := hmacSHA256(kService, []byte(awsV4Request))
return signingKey
}
func formatShortTime(dt time.Time) string {
return dt.UTC().Format(shortTimeFormat)
}
func formatTime(dt time.Time) string {
return dt.UTC().Format(timeFormat)
}

View file

@ -2,18 +2,24 @@ package aws
import ( import (
"io" "io"
"strings"
"sync" "sync"
"github.com/aws/aws-sdk-go/internal/sdkio" "github.com/aws/aws-sdk-go/internal/sdkio"
) )
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should // ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the
// only be used with an io.Reader that is also an io.Seeker. Doing so may // SDK to accept an io.Reader that is not also an io.Seeker for unsigned
// cause request signature errors, or request body's not sent for GET, HEAD // streaming payload API operations.
// and DELETE HTTP methods.
// //
// Deprecated: Should only be used with io.ReadSeeker. If using for // A ReadSeekCloser wrapping an nonseekable io.Reader used in an API
// S3 PutObject to stream content use s3manager.Uploader instead. // operation's input will prevent that operation being retried in the case of
// network errors, and cause operation requests to fail if the operation
// requires payload signing.
//
// Note: If using With S3 PutObject to stream an object upload The SDK's S3
// Upload manager (s3manager.Uploader) provides support for streaming with the
// ability to retry network errors.
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser { func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r} return ReaderSeekerCloser{r}
} }
@ -43,7 +49,8 @@ func IsReaderSeekable(r io.Reader) bool {
// Read reads from the reader up to size of p. The number of bytes read, and // Read reads from the reader up to size of p. The number of bytes read, and
// error if it occurred will be returned. // error if it occurred will be returned.
// //
// If the reader is not an io.Reader zero bytes read, and nil error will be returned. // If the reader is not an io.Reader zero bytes read, and nil error will be
// returned.
// //
// Performs the same functionality as io.Reader Read // Performs the same functionality as io.Reader Read
func (r ReaderSeekerCloser) Read(p []byte) (int, error) { func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
@ -199,3 +206,59 @@ func (b *WriteAtBuffer) Bytes() []byte {
defer b.m.Unlock() defer b.m.Unlock()
return b.buf return b.buf
} }
// MultiCloser is a utility to close multiple io.Closers within a single
// statement.
type MultiCloser []io.Closer
// Close closes all of the io.Closers making up the MultiClosers. Any
// errors that occur while closing will be returned in the order they
// occur.
func (m MultiCloser) Close() error {
var errs errors
for _, c := range m {
err := c.Close()
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errs
}
return nil
}
type errors []error
func (es errors) Error() string {
var parts []string
for _, e := range es {
parts = append(parts, e.Error())
}
return strings.Join(parts, "\n")
}
// CopySeekableBody copies the seekable body to an io.Writer
func CopySeekableBody(dst io.Writer, src io.ReadSeeker) (int64, error) {
curPos, err := src.Seek(0, sdkio.SeekCurrent)
if err != nil {
return 0, err
}
// copy errors may be assumed to be from the body.
n, err := io.Copy(dst, src)
if err != nil {
return n, err
}
// seek back to the first position after reading to reset
// the body for transmission.
_, err = src.Seek(curPos, sdkio.SeekStart)
if err != nil {
return n, err
}
return n, nil
}

View file

@ -1,3 +1,4 @@
//go:build go1.8
// +build go1.8 // +build go1.8
package aws package aws

View file

@ -1,3 +1,4 @@
//go:build !go1.8
// +build !go1.8 // +build !go1.8
package aws package aws

View file

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go" const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK // SDKVersion is the version of this SDK
const SDKVersion = "1.15.11" const SDKVersion = "1.43.16"

View file

@ -1,6 +1,9 @@
module github.com/aws/aws-sdk-go module github.com/aws/aws-sdk-go
require ( require (
github.com/go-ini/ini v1.25.4 github.com/jmespath/go-jmespath v0.4.0
github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8 github.com/pkg/errors v0.9.1
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
) )
go 1.11

View file

@ -1,6 +1,7 @@
//go:build !go1.7
// +build !go1.7 // +build !go1.7
package aws package context
import "time" import "time"
@ -30,12 +31,11 @@ func (*emptyCtx) Value(key interface{}) interface{} {
func (e *emptyCtx) String() string { func (e *emptyCtx) String() string {
switch e { switch e {
case backgroundCtx: case BackgroundCtx:
return "aws.BackgroundContext" return "aws.BackgroundContext"
} }
return "unknown empty Context" return "unknown empty Context"
} }
var ( // BackgroundCtx is the common base context.
backgroundCtx = new(emptyCtx) var BackgroundCtx = new(emptyCtx)
)

120
vendor/github.com/aws/aws-sdk-go/internal/ini/ast.go generated vendored Normal file
View file

@ -0,0 +1,120 @@
package ini
// ASTKind represents different states in the parse table
// and the type of AST that is being constructed
type ASTKind int
// ASTKind* is used in the parse table to transition between
// the different states
const (
ASTKindNone = ASTKind(iota)
ASTKindStart
ASTKindExpr
ASTKindEqualExpr
ASTKindStatement
ASTKindSkipStatement
ASTKindExprStatement
ASTKindSectionStatement
ASTKindNestedSectionStatement
ASTKindCompletedNestedSectionStatement
ASTKindCommentStatement
ASTKindCompletedSectionStatement
)
func (k ASTKind) String() string {
switch k {
case ASTKindNone:
return "none"
case ASTKindStart:
return "start"
case ASTKindExpr:
return "expr"
case ASTKindStatement:
return "stmt"
case ASTKindSectionStatement:
return "section_stmt"
case ASTKindExprStatement:
return "expr_stmt"
case ASTKindCommentStatement:
return "comment"
case ASTKindNestedSectionStatement:
return "nested_section_stmt"
case ASTKindCompletedSectionStatement:
return "completed_stmt"
case ASTKindSkipStatement:
return "skip"
default:
return ""
}
}
// AST interface allows us to determine what kind of node we
// are on and casting may not need to be necessary.
//
// The root is always the first node in Children
type AST struct {
Kind ASTKind
Root Token
RootToken bool
Children []AST
}
func newAST(kind ASTKind, root AST, children ...AST) AST {
return AST{
Kind: kind,
Children: append([]AST{root}, children...),
}
}
func newASTWithRootToken(kind ASTKind, root Token, children ...AST) AST {
return AST{
Kind: kind,
Root: root,
RootToken: true,
Children: children,
}
}
// AppendChild will append to the list of children an AST has.
func (a *AST) AppendChild(child AST) {
a.Children = append(a.Children, child)
}
// GetRoot will return the root AST which can be the first entry
// in the children list or a token.
func (a *AST) GetRoot() AST {
if a.RootToken {
return *a
}
if len(a.Children) == 0 {
return AST{}
}
return a.Children[0]
}
// GetChildren will return the current AST's list of children
func (a *AST) GetChildren() []AST {
if len(a.Children) == 0 {
return []AST{}
}
if a.RootToken {
return a.Children
}
return a.Children[1:]
}
// SetChildren will set and override all children of the AST.
func (a *AST) SetChildren(children []AST) {
if a.RootToken {
a.Children = children
} else {
a.Children = append(a.Children[:1], children...)
}
}
// Start is used to indicate the starting state of the parse table.
var Start = newAST(ASTKindStart, AST{})

View file

@ -0,0 +1,11 @@
package ini
var commaRunes = []rune(",")
func isComma(b rune) bool {
return b == ','
}
func newCommaToken() Token {
return newToken(TokenComma, commaRunes, NoneType)
}

View file

@ -0,0 +1,35 @@
package ini
// isComment will return whether or not the next byte(s) is a
// comment.
func isComment(b []rune) bool {
if len(b) == 0 {
return false
}
switch b[0] {
case ';':
return true
case '#':
return true
}
return false
}
// newCommentToken will create a comment token and
// return how many bytes were read.
func newCommentToken(b []rune) (Token, int, error) {
i := 0
for ; i < len(b); i++ {
if b[i] == '\n' {
break
}
if len(b)-i > 2 && b[i] == '\r' && b[i+1] == '\n' {
break
}
}
return newToken(TokenComment, b[:i], NoneType), i, nil
}

42
vendor/github.com/aws/aws-sdk-go/internal/ini/doc.go generated vendored Normal file
View file

@ -0,0 +1,42 @@
// Package ini is an LL(1) parser for configuration files.
//
// Example:
// sections, err := ini.OpenFile("/path/to/file")
// if err != nil {
// panic(err)
// }
//
// profile := "foo"
// section, ok := sections.GetSection(profile)
// if !ok {
// fmt.Printf("section %q could not be found", profile)
// }
//
// Below is the BNF that describes this parser
// Grammar:
// stmt -> section | stmt'
// stmt' -> epsilon | expr
// expr -> value (stmt)* | equal_expr (stmt)*
// equal_expr -> value ( ':' | '=' ) equal_expr'
// equal_expr' -> number | string | quoted_string
// quoted_string -> " quoted_string'
// quoted_string' -> string quoted_string_end
// quoted_string_end -> "
//
// section -> [ section'
// section' -> section_value section_close
// section_value -> number | string_subset | boolean | quoted_string_subset
// quoted_string_subset -> " quoted_string_subset'
// quoted_string_subset' -> string_subset quoted_string_end
// quoted_string_subset -> "
// section_close -> ]
//
// value -> number | string_subset | boolean
// string -> ? UTF-8 Code-Points except '\n' (U+000A) and '\r\n' (U+000D U+000A) ?
// string_subset -> ? Code-points excepted by <string> grammar except ':' (U+003A), '=' (U+003D), '[' (U+005B), and ']' (U+005D) ?
//
// SkipState will skip (NL WS)+
//
// comment -> # comment' | ; comment'
// comment' -> epsilon | value
package ini

View file

@ -0,0 +1,4 @@
package ini
// emptyToken is used to satisfy the Token interface
var emptyToken = newToken(TokenNone, []rune{}, NoneType)

View file

@ -0,0 +1,24 @@
package ini
// newExpression will return an expression AST.
// Expr represents an expression
//
// grammar:
// expr -> string | number
func newExpression(tok Token) AST {
return newASTWithRootToken(ASTKindExpr, tok)
}
func newEqualExpr(left AST, tok Token) AST {
return newASTWithRootToken(ASTKindEqualExpr, tok, left)
}
// EqualExprKey will return a LHS value in the equal expr
func EqualExprKey(ast AST) string {
children := ast.GetChildren()
if len(children) == 0 || ast.Kind != ASTKindEqualExpr {
return ""
}
return string(children[0].Root.Raw())
}

Some files were not shown because too many files have changed in this diff Show more