Switch to using the dep tool and update all the dependencies

This commit is contained in:
Nick Craig-Wood 2017-05-11 15:39:54 +01:00
parent 5135ff73cb
commit 98c2d2c41b
5321 changed files with 4483201 additions and 5922 deletions

View file

@ -0,0 +1,14 @@
Please fill out the sections below to help us address your issue.
### Version of AWS SDK for Go?
### Version of Go (`go version`)?
### What issue did you see?
### Steps to reproduce
If you have have an runnable example, please include it.

View file

@ -0,0 +1,3 @@
For changes to files under the `/model/` folder, and manual edits to autogenerated code (e.g. `/service/s3/api.go`) please create an Issue instead of a PR for those type of changes.
If there is an existing bug or feature this PR is answers please reference it here.

11
vendor/github.com/aws/aws-sdk-go/.gitignore generated vendored Normal file
View file

@ -0,0 +1,11 @@
dist
/doc
/doc-staging
.yardoc
Gemfile.lock
awstesting/integration/smoke/**/importmarker__.go
awstesting/integration/smoke/_test/
/vendor/bin/
/vendor/pkg/
/vendor/src/
/private/model/cli/gen-api/gen-api

14
vendor/github.com/aws/aws-sdk-go/.godoc_config generated vendored Normal file
View file

@ -0,0 +1,14 @@
{
"PkgHandler": {
"Pattern": "/sdk-for-go/api/",
"StripPrefix": "/sdk-for-go/api",
"Include": ["/src/github.com/aws/aws-sdk-go/aws", "/src/github.com/aws/aws-sdk-go/service"],
"Exclude": ["/src/cmd", "/src/github.com/aws/aws-sdk-go/awstesting", "/src/github.com/aws/aws-sdk-go/awsmigrate"],
"IgnoredSuffixes": ["iface"]
},
"Github": {
"Tag": "master",
"Repo": "/aws/aws-sdk-go",
"UseGithub": true
}
}

24
vendor/github.com/aws/aws-sdk-go/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,24 @@
language: go
sudo: false
go:
- 1.5
- 1.6
- 1.7
- 1.8
- tip
# Use Go 1.5's vendoring experiment for 1.5 tests.
env:
- GO15VENDOREXPERIMENT=1
install:
- make get-deps
script:
- make unit-with-race-cover
matrix:
allow_failures:
- go: tip

1069
vendor/github.com/aws/aws-sdk-go/CHANGELOG.md generated vendored Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,5 @@
### SDK Features
### SDK Enhancements
### SDK Bugs

127
vendor/github.com/aws/aws-sdk-go/CONTRIBUTING.md generated vendored Normal file
View file

@ -0,0 +1,127 @@
Contributing to the AWS SDK for Go
We work hard to provide a high-quality and useful SDK, and we greatly value
feedback and contributions from our community. Whether it's a bug report,
new feature, correction, or additional documentation, we welcome your issues
and pull requests. Please read through this document before submitting any
issues or pull requests to ensure we have all the necessary information to
effectively respond to your bug report or contribution.
## Filing Bug Reports
You can file bug reports against the SDK on the [GitHub issues][issues] page.
If you are filing a report for a bug or regression in the SDK, it's extremely
helpful to provide as much information as possible when opening the original
issue. This helps us reproduce and investigate the possible bug without having
to wait for this extra information to be provided. Please read the following
guidelines prior to filing a bug report.
1. Search through existing [issues][] to ensure that your specific issue has
not yet been reported. If it is a common issue, it is likely there is
already a bug report for your problem.
2. Ensure that you have tested the latest version of the SDK. Although you
may have an issue against an older version of the SDK, we cannot provide
bug fixes for old versions. It's also possible that the bug may have been
fixed in the latest release.
3. Provide as much information about your environment, SDK version, and
relevant dependencies as possible. For example, let us know what version
of Go you are using, which and version of the operating system, and the
the environment your code is running in. e.g Container.
4. Provide a minimal test case that reproduces your issue or any error
information you related to your problem. We can provide feedback much
more quickly if we know what operations you are calling in the SDK. If
you cannot provide a full test case, provide as much code as you can
to help us diagnose the problem. Any relevant information should be provided
as well, like whether this is a persistent issue, or if it only occurs
some of the time.
## Submitting Pull Requests
We are always happy to receive code and documentation contributions to the SDK.
Please be aware of the following notes prior to opening a pull request:
1. The SDK is released under the [Apache license][license]. Any code you submit
will be released under that license. For substantial contributions, we may
ask you to sign a [Contributor License Agreement (CLA)][cla].
2. If you would like to implement support for a significant feature that is not
yet available in the SDK, please talk to us beforehand to avoid any
duplication of effort.
3. Wherever possible, pull requests should contain tests as appropriate.
Bugfixes should contain tests that exercise the corrected behavior (i.e., the
test should fail without the bugfix and pass with it), and new features
should be accompanied by tests exercising the feature.
4. Pull requests that contain failing tests will not be merged until the test
failures are addressed. Pull requests that cause a significant drop in the
SDK's test coverage percentage are unlikely to be merged until tests have
been added.
5. The JSON files under the SDK's `models` folder are sourced from outside the SDK.
Such as `models/apis/ec2/2016-11-15/api.json`. We will not accept pull requests
directly on these models. If you discover an issue with the models please
create a Github [issue](issues) describing the issue.
### Testing
To run the tests locally, running the `make unit` command will `go get` the
SDK's testing dependencies, and run vet, link and unit tests for the SDK.
```
make unit
```
Standard go testing functionality is supported as well. To test SDK code that
is tagged with `codegen` you'll need to set the build tag in the go test
command. The `make unit` command will do this automatically.
```
go test -tags codegen ./private/...
```
See the `Makefile` for additional testing tags that can be used in testing.
To test on multiple platform the SDK includes several DockerFiles under the
`awstesting/sandbox` folder, and associated make recipes to to execute
unit testing within environments configured for specific Go versions.
```
make sandbox-test-go18
```
To run all sandbox environments use the following make recipe
```
# Optionally update the Go tip that will be used during the batch testing
make update-aws-golang-tip
# Run all SDK tests for supported Go versions in sandboxes
make sandbox-test
```
In addition the sandbox environment include make recipes for interactive modes
so you can run command within the Docker container and context of the SDK.
```
make sandbox-go18
```
### Changelog
You can see all release changes in the `CHANGELOG.md` file at the root of the
repository. The release notes added to this file will contain service client
updates, and major SDK changes.
[issues]: https://github.com/aws/aws-sdk-go/issues
[pr]: https://github.com/aws/aws-sdk-go/pulls
[license]: http://aws.amazon.com/apache2.0/
[cla]: http://en.wikipedia.org/wiki/Contributor_License_Agreement
[releasenotes]: https://github.com/aws/aws-sdk-go/releases

179
vendor/github.com/aws/aws-sdk-go/Makefile generated vendored Normal file
View file

@ -0,0 +1,179 @@
LINTIGNOREDOT='awstesting/integration.+should not use dot imports'
LINTIGNOREDOC='service/[^/]+/(api|service|waiters)\.go:.+(comment on exported|should have comment or be unexported)'
LINTIGNORECONST='service/[^/]+/(api|service|waiters)\.go:.+(type|struct field|const|func) ([^ ]+) should be ([^ ]+)'
LINTIGNORESTUTTER='service/[^/]+/(api|service)\.go:.+(and that stutters)'
LINTIGNOREINFLECT='service/[^/]+/(api|errors|service)\.go:.+(method|const) .+ should be '
LINTIGNOREINFLECTS3UPLOAD='service/s3/s3manager/upload\.go:.+struct field SSEKMSKeyId should be '
LINTIGNOREDEPS='vendor/.+\.go'
LINTIGNOREPKGCOMMENT='service/[^/]+/doc_custom.go:.+package comment should be of the form'
UNIT_TEST_TAGS="example codegen"
SDK_WITH_VENDOR_PKGS=$(shell go list -tags ${UNIT_TEST_TAGS} ./... | grep -v "/vendor/src")
SDK_ONLY_PKGS=$(shell go list ./... | grep -v "/vendor/")
SDK_UNIT_TEST_ONLY_PKGS=$(shell go list -tags ${UNIT_TEST_TAGS} ./... | grep -v "/vendor/")
SDK_GO_1_4=$(shell go version | grep "go1.4")
SDK_GO_1_5=$(shell go version | grep "go1.5")
SDK_GO_VERSION=$(shell go version | awk '''{print $$3}''' | tr -d '''\n''')
all: get-deps generate unit
help:
@echo "Please use \`make <target>' where <target> is one of"
@echo " api_info to print a list of services and versions"
@echo " docs to build SDK documentation"
@echo " build to go build the SDK"
@echo " unit to run unit tests"
@echo " integration to run integration tests"
@echo " performance to run performance tests"
@echo " verify to verify tests"
@echo " lint to lint the SDK"
@echo " vet to vet the SDK"
@echo " generate to go generate and make services"
@echo " gen-test to generate protocol tests"
@echo " gen-services to generate services"
@echo " get-deps to go get the SDK dependencies"
@echo " get-deps-tests to get the SDK's test dependencies"
@echo " get-deps-verify to get the SDK's verification dependencies"
generate: gen-test gen-endpoints gen-services
gen-test: gen-protocol-test
gen-services:
go generate ./service
gen-protocol-test:
go generate ./private/protocol/...
gen-endpoints:
go generate ./models/endpoints/
build:
@echo "go build SDK and vendor packages"
@go build ${SDK_ONLY_PKGS}
unit: get-deps-tests build verify
@echo "go test SDK and vendor packages"
@go test -tags ${UNIT_TEST_TAGS} $(SDK_UNIT_TEST_ONLY_PKGS)
unit-with-race-cover: get-deps-tests build verify
@echo "go test SDK and vendor packages"
@go test -tags ${UNIT_TEST_TAGS} -race -cpu=1,2,4 $(SDK_UNIT_TEST_ONLY_PKGS)
integration: get-deps-tests integ-custom smoke-tests performance
integ-custom:
go test -tags "integration" ./awstesting/integration/customizations/...
cleanup-integ:
go run -tags "integration" ./awstesting/cmd/bucket_cleanup/main.go "aws-sdk-go-integration"
smoke-tests: get-deps-tests
gucumber -go-tags "integration" ./awstesting/integration/smoke
performance: get-deps-tests
AWS_TESTING_LOG_RESULTS=${log-detailed} AWS_TESTING_REGION=$(region) AWS_TESTING_DB_TABLE=$(table) gucumber -go-tags "integration" ./awstesting/performance
sandbox-tests: sandbox-test-go15 sandbox-test-go15-novendorexp sandbox-test-go16 sandbox-test-go17 sandbox-test-go18 sandbox-test-gotip
sandbox-build-go15:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.5 -t "aws-sdk-go-1.5" .
sandbox-go15: sandbox-build-go15
docker run -i -t aws-sdk-go-1.5 bash
sandbox-test-go15: sandbox-build-go15
docker run -t aws-sdk-go-1.5
sandbox-build-go15-novendorexp:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.5-novendorexp -t "aws-sdk-go-1.5-novendorexp" .
sandbox-go15-novendorexp: sandbox-build-go15-novendorexp
docker run -i -t aws-sdk-go-1.5-novendorexp bash
sandbox-test-go15-novendorexp: sandbox-build-go15-novendorexp
docker run -t aws-sdk-go-1.5-novendorexp
sandbox-build-go16:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.6 -t "aws-sdk-go-1.6" .
sandbox-go16: sandbox-build-go16
docker run -i -t aws-sdk-go-1.6 bash
sandbox-test-go16: sandbox-build-go16
docker run -t aws-sdk-go-1.6
sandbox-build-go17:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.7 -t "aws-sdk-go-1.7" .
sandbox-go17: sandbox-build-go17
docker run -i -t aws-sdk-go-1.7 bash
sandbox-test-go17: sandbox-build-go17
docker run -t aws-sdk-go-1.7
sandbox-build-go18:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.8 -t "aws-sdk-go-1.8" .
sandbox-go18: sandbox-build-go18
docker run -i -t aws-sdk-go-1.8 bash
sandbox-test-go18: sandbox-build-go18
docker run -t aws-sdk-go-1.8
sandbox-build-gotip:
@echo "Run make update-aws-golang-tip, if this test fails because missing aws-golang:tip container"
docker build -f ./awstesting/sandbox/Dockerfile.test.gotip -t "aws-sdk-go-tip" .
sandbox-gotip: sandbox-build-gotip
docker run -i -t aws-sdk-go-tip bash
sandbox-test-gotip: sandbox-build-gotip
docker run -t aws-sdk-go-tip
update-aws-golang-tip:
docker build --no-cache=true -f ./awstesting/sandbox/Dockerfile.golang-tip -t "aws-golang:tip" .
verify: get-deps-verify lint vet
lint:
@echo "go lint SDK and vendor packages"
@lint=`if [ \( -z "${SDK_GO_1_4}" \) -a \( -z "${SDK_GO_1_5}" \) ]; then golint ./...; else echo "skipping golint"; fi`; \
lint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOT} -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT}`; \
echo "$$lint"; \
if [ "$$lint" != "" ] && [ "$$lint" != "skipping golint" ]; then exit 1; fi
SDK_BASE_FOLDERS=$(shell ls -d */ | grep -v vendor | grep -v awsmigrate)
ifneq (,$(findstring go1.4, ${SDK_GO_VERSION}))
GO_VET_CMD=echo skipping go vet, ${SDK_GO_VERSION}
else ifneq (,$(findstring go1.6, ${SDK_GO_VERSION}))
GO_VET_CMD=go tool vet --all -shadow -example=false
else
GO_VET_CMD=go tool vet --all -shadow
endif
vet:
${GO_VET_CMD} ${SDK_BASE_FOLDERS}
get-deps: get-deps-tests get-deps-verify
@echo "go get SDK dependencies"
@go get -v $(SDK_ONLY_PKGS)
get-deps-tests:
@echo "go get SDK testing dependencies"
go get github.com/gucumber/gucumber/cmd/gucumber
go get github.com/stretchr/testify
go get github.com/smartystreets/goconvey
go get golang.org/x/net/html
get-deps-verify:
@echo "go get SDK verification utilities"
@if [ \( -z "${SDK_GO_1_4}" \) -a \( -z "${SDK_GO_1_5}" \) ]; then go get github.com/golang/lint/golint; else echo "skipped getting golint"; fi
bench:
@echo "go bench SDK packages"
@go test -run NONE -bench . -benchmem -tags 'bench' $(SDK_ONLY_PKGS)
bench-protocol:
@echo "go bench SDK protocol marshallers"
@go test -run NONE -bench . -benchmem -tags 'bench' ./private/protocol/...
docs:
@echo "generate SDK docs"
@# This env variable, DOCS, is for internal use
@if [ -z ${AWS_DOC_GEN_TOOL} ]; then\
rm -rf doc && bundle install && bundle exec yard;\
else\
$(AWS_DOC_GEN_TOOL) `pwd`;\
fi
api_info:
@go run private/model/cli/api-info/api-info.go

449
vendor/github.com/aws/aws-sdk-go/README.md generated vendored Normal file
View file

@ -0,0 +1,449 @@
[![API Reference](http://img.shields.io/badge/api-reference-blue.svg)](http://docs.aws.amazon.com/sdk-for-go/api) [![Join the chat at https://gitter.im/aws/aws-sdk-go](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/aws/aws-sdk-go?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build Status](https://img.shields.io/travis/aws/aws-sdk-go.svg)](https://travis-ci.org/aws/aws-sdk-go) [![Apache V2 License](http://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/aws/aws-sdk-go/blob/master/LICENSE.txt)
# AWS SDK for Go
aws-sdk-go is the official AWS SDK for the Go programming language.
Checkout our [release notes](https://github.com/aws/aws-sdk-go/releases) for information about the latest bug fixes, updates, and features added to the SDK.
## Installing
If you are using Go 1.5 with the `GO15VENDOREXPERIMENT=1` vendoring flag, or 1.6 and higher you can use the following command to retrieve the SDK. The SDK's non-testing dependencies will be included and are vendored in the `vendor` folder.
go get -u github.com/aws/aws-sdk-go
Otherwise if your Go environment does not have vendoring support enabled, or you do not want to include the vendored SDK's dependencies you can use the following command to retrieve the SDK and its non-testing dependencies using `go get`.
go get -u github.com/aws/aws-sdk-go/aws/...
go get -u github.com/aws/aws-sdk-go/service/...
If you're looking to retrieve just the SDK without any dependencies use the following command.
go get -d github.com/aws/aws-sdk-go/
These two processes will still include the `vendor` folder and it should be deleted if its not going to be used by your environment.
rm -rf $GOPATH/src/github.com/aws/aws-sdk-go/vendor
## Getting Help
Please use these community resources for getting help. We use the GitHub issues for tracking bugs and feature requests.
* Ask a question on [StackOverflow](http://stackoverflow.com/) and tag it with the [`aws-sdk-go`](http://stackoverflow.com/questions/tagged/aws-sdk-go) tag.
* Come join the AWS SDK for Go community chat on [gitter](https://gitter.im/aws/aws-sdk-go).
* Open a support ticket with [AWS Support](http://docs.aws.amazon.com/awssupport/latest/user/getting-started.html).
* If you think you may of found a bug, please open an [issue](https://github.com/aws/aws-sdk-go/issues/new).
## Opening Issues
If you encounter a bug with the AWS SDK for Go we would like to hear about it. Search the [existing issues](https://github.com/aws/aws-sdk-go/issues) and see if others are also experiencing the issue before opening a new issue. Please include the version of AWS SDK for Go, Go language, and OS youre using. Please also include repro case when appropriate.
The GitHub issues are intended for bug reports and feature requests. For help and questions with using AWS SDK for GO please make use of the resources listed in the [Getting Help](https://github.com/aws/aws-sdk-go#getting-help) section. Keeping the list of open issues lean will help us respond in a timely manner.
## Reference Documentation
[`Getting Started Guide`](https://aws.amazon.com/sdk-for-go/) - This document is a general introduction how to configure and make requests with the SDK. If this is your first time using the SDK, this documentation and the API documentation will help you get started. This document focuses on the syntax and behavior of the SDK. The [Service Developer Guide](https://aws.amazon.com/documentation/) will help you get started using specific AWS services.
[`SDK API Reference Documentation`](https://docs.aws.amazon.com/sdk-for-go/api/) - Use this document to look up all API operation input and output parameters for AWS services supported by the SDK. The API reference also includes documentation of the SDK, and examples how to using the SDK, service client API operations, and API operation require parameters.
[`Service Developer Guide`](https://aws.amazon.com/documentation/) - Use this documentation to learn how to interface with an AWS service. These are great guides both, if you're getting started with a service, or looking for more information on a service. You should not need this document for coding, though in some cases, services may supply helpful samples that you might want to look out for.
[`SDK Examples`](https://github.com/aws/aws-sdk-go/tree/master/example) - Included in the SDK's repo are a several hand crafted examples using the SDK features and AWS services.
## Overview of SDK's Packages
The SDK is composed of two main components, SDK core, and service clients.
The SDK core packages are all available under the aws package at the root of
the SDK. Each client for a supported AWS service is available within its own
package under the service folder at the root of the SDK.
* aws - SDK core, provides common shared types such as Config, Logger,
and utilities to make working with API parameters easier.
* awserr - Provides the error interface that the SDK will use for all
errors that occur in the SDK's processing. This includes service API
response errors as well. The Error type is made up of a code and message.
Cast the SDK's returned error type to awserr.Error and call the Code
method to compare returned error to specific error codes. See the package's
documentation for additional values that can be extracted such as RequestID.
* credentials - Provides the types and built in credentials providers
the SDK will use to retrieve AWS credentials to make API requests with.
Nested under this folder are also additional credentials providers such as
stscreds for assuming IAM roles, and ec2rolecreds for EC2 Instance roles.
* endpoints - Provides the AWS Regions and Endpoints metadata for the SDK.
Use this to lookup AWS service endpoint information such as which services
are in a region, and what regions a service is in. Constants are also provided
for all region identifiers, e.g UsWest2RegionID for "us-west-2".
* session - Provides initial default configuration, and load
configuration from external sources such as environment and shared
credentials file.
* request - Provides the API request sending, and retry logic for the SDK.
This package also includes utilities for defining your own request
retryer, and configuring how the SDK processes the request.
* service - Clients for AWS services. All services supported by the SDK are
available under this folder.
## How to Use the SDK's AWS Service Clients
The SDK includes the Go types and utilities you can use to make requests to
AWS service APIs. Within the service folder at the root of the SDK you'll find
a package for each AWS service the SDK supports. All service clients follows
a common pattern of creation and usage.
When creating a client for an AWS service you'll first need to have a Session
value constructed. The Session provides shared configuration that can be shared
between your service clients. When service clients are created you can pass
in additional configuration via the aws.Config type to override configuration
provided by in the Session to create service client instances with custom
configuration.
Once the service's client is created you can use it to make API requests the
AWS service. These clients are safe to use concurrently.
## Configuring the SDK
In the AWS SDK for Go, you can configure settings for service clients, such
as the log level and maximum number of retries. Most settings are optional;
however, for each service client, you must specify a region and your credentials.
The SDK uses these values to send requests to the correct AWS region and sign
requests with the correct credentials. You can specify these values as part
of a session or as environment variables.
See the SDK's [configuration guide][config_guide] for more information.
See the [session][session_pkg] package documentation for more information on how to use Session
with the SDK.
See the [Config][config_typ] type in the [aws][aws_pkg] package for more information on configuration
options.
[config_guide]: https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html
[session_pkg]: https://docs.aws.amazon.com/sdk-for-go/api/aws/session/
[config_typ]: https://docs.aws.amazon.com/sdk-for-go/api/aws/#Config
[aws_pkg]: https://docs.aws.amazon.com/sdk-for-go/api/aws/
### Configuring Credentials
When using the SDK you'll generally need your AWS credentials to authenticate
with AWS services. The SDK supports multiple methods of supporting these
credentials. By default the SDK will source credentials automatically from
its default credential chain. See the session package for more information
on this chain, and how to configure it. The common items in the credential
chain are the following:
* Environment Credentials - Set of environment variables that are useful
when sub processes are created for specific roles.
* Shared Credentials file (~/.aws/credentials) - This file stores your
credentials based on a profile name and is useful for local development.
* EC2 Instance Role Credentials - Use EC2 Instance Role to assign credentials
to application running on an EC2 instance. This removes the need to manage
credential files in production.
Credentials can be configured in code as well by setting the Config's Credentials
value to a custom provider or using one of the providers included with the
SDK to bypass the default credential chain and use a custom one. This is
helpful when you want to instruct the SDK to only use a specific set of
credentials or providers.
This example creates a credential provider for assuming an IAM role, "myRoleARN"
and configures the S3 service client to use that role for API requests.
```go
// Initial credentials loaded from SDK's default credential chain. Such as
// the environment, shared credentials (~/.aws/credentials), or EC2 Instance
// Role. These credentials will be used to to make the STS Assume Role API.
sess := session.Must(session.NewSession())
// Create the credentials from AssumeRoleProvider to assume the role
// referenced by the "myRoleARN" ARN.
creds := stscreds.NewCredentials(sess, "myRoleArn")
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess, &aws.Config{Credentials: creds})/
```
See the [credentials][credentials_pkg] package documentation for more information on credential
providers included with the SDK, and how to customize the SDK's usage of
credentials.
The SDK has support for the shared configuration file (~/.aws/config). This
support can be enabled by setting the environment variable, "AWS_SDK_LOAD_CONFIG=1",
or enabling the feature in code when creating a Session via the
Option's SharedConfigState parameter.
```go
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
```
[credentials_pkg]: ttps://docs.aws.amazon.com/sdk-for-go/api/aws/credentials
### Configuring AWS Region
In addition to the credentials you'll need to specify the region the SDK
will use to make AWS API requests to. In the SDK you can specify the region
either with an environment variable, or directly in code when a Session or
service client is created. The last value specified in code wins if the region
is specified multiple ways.
To set the region via the environment variable set the "AWS_REGION" to the
region you want to the SDK to use. Using this method to set the region will
allow you to run your application in multiple regions without needing additional
code in the application to select the region.
AWS_REGION=us-west-2
The endpoints package includes constants for all regions the SDK knows. The
values are all suffixed with RegionID. These values are helpful, because they
reduce the need to type the region string manually.
To set the region on a Session use the aws package's Config struct parameter
Region to the AWS region you want the service clients created from the session to
use. This is helpful when you want to create multiple service clients, and
all of the clients make API requests to the same region.
```go
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(endpoints.UsWest2RegionID),
}))
```
See the [endpoints][endpoints_pkg] package for the AWS Regions and Endpoints metadata.
In addition to setting the region when creating a Session you can also set
the region on a per service client bases. This overrides the region of a
Session. This is helpful when you want to create service clients in specific
regions different from the Session's region.
```go
svc := s3.New(sess, &aws.Config{
Region: aws.String(endpoints.UsWest2RegionID),
})
```
See the [Config][config_typ] type in the [aws][aws_pkg] package for more information and additional
options such as setting the Endpoint, and other service client configuration options.
[endpoints_pkg]: https://docs.aws.amazon.com/sdk-for-go/api/aws/endpoints/
## Making API Requests
Once the client is created you can make an API request to the service.
Each API method takes a input parameter, and returns the service response
and an error. The SDK provides methods for making the API call in multiple ways.
In this list we'll use the S3 ListObjects API as an example for the different
ways of making API requests.
* ListObjects - Base API operation that will make the API request to the service.
* ListObjectsRequest - API methods suffixed with Request will construct the
API request, but not send it. This is also helpful when you want to get a
presigned URL for a request, and share the presigned URL instead of your
application making the request directly.
* ListObjectsPages - Same as the base API operation, but uses a callback to
automatically handle pagination of the API's response.
* ListObjectsWithContext - Same as base API operation, but adds support for
the Context pattern. This is helpful for controlling the canceling of in
flight requests. See the Go standard library context package for more
information. This method also takes request package's Option functional
options as the variadic argument for modifying how the request will be
made, or extracting information from the raw HTTP response.
* ListObjectsPagesWithContext - same as ListObjectsPages, but adds support for
the Context pattern. Similar to ListObjectsWithContext this method also
takes the request package's Option function option types as the variadic
argument.
In addition to the API operations the SDK also includes several higher level
methods that abstract checking for and waiting for an AWS resource to be in
a desired state. In this list we'll use WaitUntilBucketExists to demonstrate
the different forms of waiters.
* WaitUntilBucketExists. - Method to make API request to query an AWS service for
a resource's state. Will return successfully when that state is accomplished.
* WaitUntilBucketExistsWithContext - Same as WaitUntilBucketExists, but adds
support for the Context pattern. In addition these methods take request
package's WaiterOptions to configure the waiter, and how underlying request
will be made by the SDK.
The API method will document which error codes the service might return for
the operation. These errors will also be available as const strings prefixed
with "ErrCode" in the service client's package. If there are no errors listed
in the API's SDK documentation you'll need to consult the AWS service's API
documentation for the errors that could be returned.
```go
ctx := context.Background()
result, err := svc.GetObjectWithContext(ctx, &s3.GetObjectInput{
Bucket: aws.String("my-bucket"),
Key: aws.String("my-key"),
})
if err != nil {
// Cast err to awserr.Error to handle specific error codes.
aerr, ok := err.(awserr.Error)
if ok && aerr.Code() == s3.ErrCodeNoSuchKey {
// Specific error code handling
}
return err
}
// Make sure to close the body when done with it for S3 GetObject APIs or
// will leak connections.
defer result.Body.Close()
fmt.Println("Object Size:", aws.StringValue(result.ContentLength))
```
### API Request Pagination and Resource Waiters
Pagination helper methods are suffixed with "Pages", and provide the
functionality needed to round trip API page requests. Pagination methods
take a callback function that will be called for each page of the API's response.
```go
objects := []string{}
err := svc.ListObjectsPagesWithContext(ctx, &s3.ListObjectsInput{
Bucket: aws.String(myBucket),
}, func(p *s3.ListObjectsOutput, lastPage bool) bool {
for _, o := range p.Contents {
objects = append(objects, aws.StringValue(o.Key))
}
return true // continue paging
})
if err != nil {
panic(fmt.Sprintf("failed to list objects for bucket, %s, %v", myBucket, err))
}
fmt.Println("Objects in bucket:", objects)
```
Waiter helper methods provide the functionality to wait for an AWS resource
state. These methods abstract the logic needed to to check the state of an
AWS resource, and wait until that resource is in a desired state. The waiter
will block until the resource is in the state that is desired, an error occurs,
or the waiter times out. If a resource times out the error code returned will
be request.WaiterResourceNotReadyErrorCode.
```go
err := svc.WaitUntilBucketExistsWithContext(ctx, &s3.HeadBucketInput{
Bucket: aws.String(myBucket),
})
if err != nil {
aerr, ok := err.(awserr.Error)
if ok && aerr.Code() == request.WaiterResourceNotReadyErrorCode {
fmt.Fprintf(os.Stderr, "timed out while waiting for bucket to exist")
}
panic(fmt.Errorf("failed to wait for bucket to exist, %v", err))
}
fmt.Println("Bucket", myBucket, "exists")
```
## Complete SDK Example
This example shows a complete working Go file which will upload a file to S3
and use the Context pattern to implement timeout logic that will cancel the
request if it takes too long. This example highlights how to use sessions,
create a service client, make a request, handle the error, and process the
response.
```go
package main
import (
"context"
"flag"
"fmt"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
// Uploads a file to S3 given a bucket and object key. Also takes a duration
// value to terminate the update if it doesn't complete within that time.
//
// The AWS Region needs to be provided in the AWS shared config or on the
// environment variable as `AWS_REGION`. Credentials also must be provided
// Will default to shared config file, but can load from environment if provided.
//
// Usage:
// # Upload myfile.txt to myBucket/myKey. Must complete within 10 minutes or will fail
// go run withContext.go -b mybucket -k myKey -d 10m < myfile.txt
func main() {
var bucket, key string
var timeout time.Duration
flag.StringVar(&bucket, "b", "", "Bucket name.")
flag.StringVar(&key, "k", "", "Object key name.")
flag.DurationVar(&timeout, "d", 0, "Upload timeout.")
flag.Parse()
// All clients require a Session. The Session provides the client with
// shared configuration such as region, endpoint, and credentials. A
// Session should be shared where possible to take advantage of
// configuration and credential caching. See the session package for
// more information.
sess := session.Must(session.NewSession())
// Create a new instance of the service's client with a Session.
// Optional aws.Config values can also be provided as variadic arguments
// to the New function. This option allows you to provide service
// specific configuration.
svc := s3.New(sess)
// Create a context with a timeout that will abort the upload if it takes
// more than the passed in timeout.
ctx := context.Background()
var cancelFn func()
if timeout > 0 {
ctx, cancelFn = context.WithTimeout(ctx, timeout)
}
// Ensure the context is canceled to prevent leaking.
// See context package for more information, https://golang.org/pkg/context/
defer cancelFn()
// Uploads the object to S3. The Context will interrupt the request if the
// timeout expires.
_, err := svc.PutObjectWithContext(ctx, &s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Body: os.Stdin,
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok && aerr.Code() == request.CanceledErrorCode {
// If the SDK can determine the request or retry delay was canceled
// by a context the CanceledErrorCode error code will be returned.
fmt.Fprintf(os.Stderr, "upload canceled due to timeout, %v\n", err)
} else {
fmt.Fprintf(os.Stderr, "failed to upload object, %v\n", err)
}
os.Exit(1)
}
fmt.Printf("successfully uploaded file to %s/%s\n", bucket, key)
}
```
## License
This SDK is distributed under the
[Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0),
see LICENSE.txt and NOTICE.txt for more information.

View file

@ -0,0 +1,265 @@
package awsutil_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func ExampleCopy() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Print the result
fmt.Println(awsutil.Prettify(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}
func TestCopy1(t *testing.T) {
type Bar struct {
a *int
B *int
c int
D int
}
type Foo struct {
A int
B []*string
C map[string]*int
D *time.Time
E *Bar
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
intPtr1 := 1
intPtr2 := 2
now := time.Now()
f1 := &Foo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
D: &now,
E: &Bar{
&intPtr1,
&intPtr2,
2,
3,
},
}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, f2.D, f1.D)
assert.Equal(t, f2.E.B, f1.E.B)
assert.Equal(t, f2.E.D, f1.E.D)
// But pointers are not!
str3 := "nothello"
int3 := 57
f2.A = 100
*f2.B[0] = str3
*f2.C["B"] = int3
*f2.D = time.Now()
f2.E.a = &int3
*f2.E.B = int3
f2.E.c = 5
f2.E.D = 5
assert.NotEqual(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.C, f1.C)
assert.NotEqual(t, f2.D, f1.D)
assert.NotEqual(t, f2.E.a, f1.E.a)
assert.NotEqual(t, f2.E.B, f1.E.B)
assert.NotEqual(t, f2.E.c, f1.E.c)
assert.NotEqual(t, f2.E.D, f1.E.D)
}
func TestCopyNestedWithUnexported(t *testing.T) {
type Bar struct {
a int
B int
}
type Foo struct {
A string
B Bar
}
f1 := &Foo{A: "string", B: Bar{a: 1, B: 2}}
var f2 Foo
awsutil.Copy(&f2, f1)
// Values match
assert.Equal(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.B.a, f1.B.a)
assert.Equal(t, f2.B.B, f2.B.B)
}
func TestCopyIgnoreNilMembers(t *testing.T) {
type Foo struct {
A *string
B []string
C map[string]string
}
f := &Foo{}
assert.Nil(t, f.A)
assert.Nil(t, f.B)
assert.Nil(t, f.C)
var f2 Foo
awsutil.Copy(&f2, f)
assert.Nil(t, f2.A)
assert.Nil(t, f2.B)
assert.Nil(t, f2.C)
fcopy := awsutil.CopyOf(f)
f3 := fcopy.(*Foo)
assert.Nil(t, f3.A)
assert.Nil(t, f3.B)
assert.Nil(t, f3.C)
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
assert.Equal(t, "hello", s)
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
assert.Equal(t, "", s)
}
func TestCopyReader(t *testing.T) {
var buf io.Reader = bytes.NewReader([]byte("hello world"))
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world"), b)
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(""), b)
}
func TestCopyDifferentStructs(t *testing.T) {
type SrcFoo struct {
A int
B []*string
C map[string]*int
SrcUnique string
SameNameDiffType int
unexportedPtr *int
ExportedPtr *int
}
type DstFoo struct {
A int
B []*string
C map[string]*int
DstUnique int
SameNameDiffType string
unexportedPtr *int
ExportedPtr *int
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &SrcFoo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
SrcUnique: "unique",
SameNameDiffType: 1,
unexportedPtr: &int1,
ExportedPtr: &int2,
}
// Do the copy
var f2 DstFoo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, "unique", f1.SrcUnique)
assert.Equal(t, 1, f1.SameNameDiffType)
assert.Equal(t, 0, f2.DstUnique)
assert.Equal(t, "", f2.SameNameDiffType)
assert.Equal(t, int1, *f1.unexportedPtr)
assert.Nil(t, f2.unexportedPtr)
assert.Equal(t, int2, *f1.ExportedPtr)
assert.Equal(t, int2, *f2.ExportedPtr)
}
func ExampleCopyOf() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
v := awsutil.CopyOf(f1)
var f2 *Foo = v.(*Foo)
// Print the result
fmt.Println(awsutil.Prettify(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}

View file

@ -0,0 +1,29 @@
package awsutil_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func TestDeepEqual(t *testing.T) {
cases := []struct {
a, b interface{}
equal bool
}{
{"a", "a", true},
{"a", "b", false},
{"a", aws.String(""), false},
{"a", nil, false},
{"a", aws.String("a"), true},
{(*bool)(nil), (*bool)(nil), true},
{(*bool)(nil), (*string)(nil), false},
{nil, nil, true},
}
for i, c := range cases {
assert.Equal(t, c.equal, awsutil.DeepEqual(c.a, c.b), "%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
}
}

View file

@ -0,0 +1,142 @@
package awsutil_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
type Struct struct {
A []Struct
z []Struct
B *Struct
D *Struct
C string
E map[string]string
}
var data = Struct{
A: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
z: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
B: &Struct{B: &Struct{C: "terminal"}, D: &Struct{C: "terminal2"}},
C: "initial",
}
var data2 = Struct{A: []Struct{
{A: []Struct{{C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}}},
{A: []Struct{{C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}}},
}}
func TestValueAtPathSuccess(t *testing.T) {
var testCases = []struct {
expect []interface{}
data interface{}
path string
}{
{[]interface{}{"initial"}, data, "C"},
{[]interface{}{"value1"}, data, "A[0].C"},
{[]interface{}{"value2"}, data, "A[1].C"},
{[]interface{}{"value3"}, data, "A[2].C"},
{[]interface{}{"value3"}, data, "a[2].c"},
{[]interface{}{"value3"}, data, "A[-1].C"},
{[]interface{}{"value1", "value2", "value3"}, data, "A[].C"},
{[]interface{}{"terminal"}, data, "B . B . C"},
{[]interface{}{"initial"}, data, "A.D.X || C"},
{[]interface{}{"initial"}, data, "A[0].B || C"},
{[]interface{}{
Struct{A: []Struct{{C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}}},
Struct{A: []Struct{{C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}}},
}, data2, "A"},
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
}
}
func TestValueAtPathFailure(t *testing.T) {
var testCases = []struct {
expect []interface{}
errContains string
data interface{}
path string
}{
{nil, "", data, "C.x"},
{nil, "SyntaxError: Invalid token: tDot", data, ".x"},
{nil, "", data, "X.Y.Z"},
{nil, "", data, "A[100].C"},
{nil, "", data, "A[3].C"},
{nil, "", data, "B.B.C.Z"},
{nil, "", data, "z[-1].C"},
{nil, "", nil, "A.B.C"},
{[]interface{}{}, "", Struct{}, "A"},
{nil, "", data, "A[0].B.C"},
{nil, "", data, "D"},
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
if c.errContains != "" {
assert.Contains(t, err.Error(), c.errContains, "case %d, expected error, %s", i, c.path)
continue
} else {
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
}
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
}
}
func TestSetValueAtPathSuccess(t *testing.T) {
var s Struct
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
assert.Equal(t, "test1", s.C)
assert.Equal(t, "test2", s.B.B.C)
assert.Equal(t, "test3", s.B.D.C)
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
assert.Equal(t, "test0", s.B.B.C)
assert.Equal(t, "test0", s.B.D.C)
var s2 Struct
awsutil.SetValueAtPath(&s2, "b.b.c", "test0")
assert.Equal(t, "test0", s2.B.B.C)
awsutil.SetValueAtPath(&s2, "A", []Struct{{}})
assert.Equal(t, []Struct{{}}, s2.A)
str := "foo"
s3 := Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", str)
assert.Equal(t, "foo", s3.B.B.C)
s3 = Struct{B: &Struct{B: &Struct{C: str}}}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", &str)
assert.Equal(t, "foo", s3.B.B.C)
var s4 struct{ Name *string }
awsutil.SetValueAtPath(&s4, "Name", str)
assert.Equal(t, str, *s4.Name)
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
s4 = struct{ Name *string }{Name: &str}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", &str)
assert.Equal(t, str, *s4.Name)
}

View file

@ -2,7 +2,6 @@ package client
import (
"fmt"
"net/http/httputil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
@ -46,7 +45,7 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
svc := &Client{
Config: cfg,
ClientInfo: info,
Handlers: handlers,
Handlers: handlers.Copy(),
}
switch retryer, ok := cfg.Retryer.(request.Retryer); {
@ -86,61 +85,6 @@ func (c *Client) AddDebugHandlers() {
return
}
c.Handlers.Send.PushFront(logRequest)
c.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
const logReqErrMsg = `DEBUG ERROR: Request %s/%s:
---[ REQUEST DUMP ERROR ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
if logBody {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.ResetBody()
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
---[ RESPONSE DUMP ERROR ]-----------------------------
%s
-----------------------------------------------------`
func logResponse(r *request.Request) {
var msg = "no response data"
if r.HTTPResponse != nil {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, err := httputil.DumpResponse(r.HTTPResponse, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logRespErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ClientInfo.ServiceName, r.Operation.Name, msg))
c.Handlers.Send.PushFrontNamed(request.NamedHandler{Name: "awssdk.client.LogRequest", Fn: logRequest})
c.Handlers.Send.PushBackNamed(request.NamedHandler{Name: "awssdk.client.LogResponse", Fn: logResponse})
}

View file

@ -0,0 +1,78 @@
package client
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
)
func pushBackTestHandler(name string, list *request.HandlerList) *bool {
called := false
(*list).PushBackNamed(request.NamedHandler{
Name: name,
Fn: func(r *request.Request) {
called = true
},
})
return &called
}
func pushFrontTestHandler(name string, list *request.HandlerList) *bool {
called := false
(*list).PushFrontNamed(request.NamedHandler{
Name: name,
Fn: func(r *request.Request) {
called = true
},
})
return &called
}
func TestNewClient_CopyHandlers(t *testing.T) {
handlers := request.Handlers{}
firstCalled := pushBackTestHandler("first", &handlers.Send)
secondCalled := pushBackTestHandler("second", &handlers.Send)
var clientHandlerCalled *bool
c := New(aws.Config{}, metadata.ClientInfo{}, handlers,
func(c *Client) {
clientHandlerCalled = pushFrontTestHandler("client handler", &c.Handlers.Send)
},
)
if e, a := 2, handlers.Send.Len(); e != a {
t.Errorf("expect %d original handlers, got %d", e, a)
}
if e, a := 3, c.Handlers.Send.Len(); e != a {
t.Errorf("expect %d client handlers, got %d", e, a)
}
handlers.Send.Run(nil)
if !*firstCalled {
t.Errorf("expect first handler to of been called")
}
*firstCalled = false
if !*secondCalled {
t.Errorf("expect second handler to of been called")
}
*secondCalled = false
if *clientHandlerCalled {
t.Errorf("expect client handler to not of been called, but was")
}
c.Handlers.Send.Run(nil)
if !*firstCalled {
t.Errorf("expect client's first handler to of been called")
}
if !*secondCalled {
t.Errorf("expect client's second handler to of been called")
}
if !*clientHandlerCalled {
t.Errorf("expect client's client handler to of been called")
}
}

View file

@ -54,6 +54,12 @@ func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable != nil {
return *r.Retryable
}
if r.HTTPResponse.StatusCode >= 500 {
return true
}

102
vendor/github.com/aws/aws-sdk-go/aws/client/logger.go generated vendored Normal file
View file

@ -0,0 +1,102 @@
package client
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http/httputil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
const logReqErrMsg = `DEBUG ERROR: Request %s/%s:
---[ REQUEST DUMP ERROR ]-----------------------------
%s
------------------------------------------------------`
type logWriter struct {
// Logger is what we will use to log the payload of a response.
Logger aws.Logger
// buf stores the contents of what has been read
buf *bytes.Buffer
}
func (logger *logWriter) Write(b []byte) (int, error) {
return logger.buf.Write(b)
}
type teeReaderCloser struct {
// io.Reader will be a tee reader that is used during logging.
// This structure will read from a body and write the contents to a logger.
io.Reader
// Source is used just to close when we are done reading.
Source io.ReadCloser
}
func (reader *teeReaderCloser) Close() error {
return reader.Source.Close()
}
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
if logBody {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.ResetBody()
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
---[ RESPONSE DUMP ERROR ]-----------------------------
%s
-----------------------------------------------------`
func logResponse(r *request.Request) {
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
r.HTTPResponse.Body = &teeReaderCloser{
Reader: io.TeeReader(r.HTTPResponse.Body, lw),
Source: r.HTTPResponse.Body,
}
handlerFn := func(req *request.Request) {
body, err := httputil.DumpResponse(req.HTTPResponse, false)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
b, err := ioutil.ReadAll(lw.buf)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
lw.Logger.Log(fmt.Sprintf(logRespMsg, req.ClientInfo.ServiceName, req.Operation.Name, string(body)))
if req.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) {
lw.Logger.Log(string(b))
}
}
r.Handlers.Unmarshal.PushBack(handlerFn)
r.Handlers.UnmarshalError.PushBack(handlerFn)
}

View file

@ -0,0 +1,57 @@
package client
import (
"bytes"
"io"
"testing"
)
type mockCloser struct {
closed bool
}
func (closer *mockCloser) Read(b []byte) (int, error) {
return 0, io.EOF
}
func (closer *mockCloser) Close() error {
closer.closed = true
return nil
}
func TestTeeReaderCloser(t *testing.T) {
expected := "FOO"
buf := bytes.NewBuffer([]byte(expected))
lw := bytes.NewBuffer(nil)
c := &mockCloser{}
closer := teeReaderCloser{
io.TeeReader(buf, lw),
c,
}
b := make([]byte, len(expected))
_, err := closer.Read(b)
closer.Close()
if expected != lw.String() {
t.Errorf("Expected %q, but received %q", expected, lw.String())
}
if err != nil {
t.Errorf("Expected 'nil', but received %v", err)
}
if !c.closed {
t.Error("Expected 'true', but received 'false'")
}
}
func TestLogWriter(t *testing.T) {
expected := "FOO"
lw := &logWriter{nil, bytes.NewBuffer(nil)}
lw.Write([]byte(expected))
if expected != lw.buf.String() {
t.Errorf("Expected %q, but received %q", expected, lw.buf.String())
}
}

View file

@ -22,9 +22,9 @@ type RequestRetryer interface{}
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
// sess, err := session.NewSession(&aws.Config{
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
// })
// }))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, &aws.Config{
@ -53,6 +53,13 @@ type Config struct {
// to use based on region.
EndpointResolver endpoints.Resolver
// EnforceShouldRetryCheck is used in the AfterRetryHandler to always call
// ShouldRetry regardless of whether or not if request.Retryable is set.
// This will utilize ShouldRetry method of custom retryers. If EnforceShouldRetryCheck
// is not set, then ShouldRetry will only be called if request.Retryable is nil.
// Proper handling of the request.Retryable field is important when setting this field.
EnforceShouldRetryCheck *bool
// The region to send requests to. This parameter is required and must
// be configured globally or on a per-client basis unless otherwise
// noted. A full list of regions is found in the "Regions and Endpoints"
@ -154,7 +161,8 @@ type Config struct {
// the EC2Metadata overriding the timeout for default credentials chain.
//
// Example:
// sess, err := session.NewSession(aws.NewConfig().WithEC2MetadataDiableTimeoutOverride(true))
// sess := session.Must(session.NewSession(aws.NewConfig()
// .WithEC2MetadataDiableTimeoutOverride(true)))
//
// svc := s3.New(sess)
//
@ -174,7 +182,7 @@ type Config struct {
//
// Only supported with.
//
// sess, err := session.NewSession()
// sess := session.Must(session.NewSession())
//
// svc := s3.New(sess, &aws.Config{
// UseDualStack: aws.Bool(true),
@ -186,13 +194,19 @@ type Config struct {
// request delays. This value should only be used for testing. To adjust
// the delay of a request see the aws/client.DefaultRetryer and
// aws/request.Retryer.
//
// SleepDelay will prevent any Context from being used for canceling retry
// delay of an API operation. It is recommended to not use SleepDelay at all
// and specify a Retryer instead.
SleepDelay func(time.Duration)
// DisableRestProtocolURICleaning will not clean the URL path when making rest protocol requests.
// Will default to false. This would only be used for empty directory names in s3 requests.
//
// Example:
// sess, err := session.NewSession(&aws.Config{DisableRestProtocolURICleaning: aws.Bool(true))
// sess := session.Must(session.NewSession(&aws.Config{
// DisableRestProtocolURICleaning: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
@ -207,9 +221,9 @@ type Config struct {
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
// sess, err := session.NewSession(aws.NewConfig().
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
// )
// ))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, aws.NewConfig().
@ -436,6 +450,10 @@ func mergeInConfig(dst *Config, other *Config) {
if other.DisableRestProtocolURICleaning != nil {
dst.DisableRestProtocolURICleaning = other.DisableRestProtocolURICleaning
}
if other.EnforceShouldRetryCheck != nil {
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
}
}
// Copy will return a shallow copy of the Config object. If any additional

86
vendor/github.com/aws/aws-sdk-go/aws/config_test.go generated vendored Normal file
View file

@ -0,0 +1,86 @@
package aws
import (
"net/http"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var testCredentials = credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
var copyTestConfig = Config{
Credentials: testCredentials,
Endpoint: String("CopyTestEndpoint"),
Region: String("COPY_TEST_AWS_REGION"),
DisableSSL: Bool(true),
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(3),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
}
func TestCopy(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if !reflect.DeepEqual(*got, want) {
t.Errorf("Copy() = %+v", got)
t.Errorf(" want %+v", want)
}
got.Region = String("other")
if got.Region == want.Region {
t.Errorf("Expect setting copy values not not reflect in source")
}
}
func TestCopyReturnsNewInstance(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if got == &want {
t.Errorf("Copy() = %p; want different instance as source %p", got, &want)
}
}
var mergeTestZeroValueConfig = Config{}
var mergeTestConfig = Config{
Credentials: testCredentials,
Endpoint: String("MergeTestEndpoint"),
Region: String("MERGE_TEST_AWS_REGION"),
DisableSSL: Bool(true),
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(10),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
}
var mergeTests = []struct {
cfg *Config
in *Config
want *Config
}{
{&Config{}, nil, &Config{}},
{&Config{}, &mergeTestZeroValueConfig, &Config{}},
{&Config{}, &mergeTestConfig, &mergeTestConfig},
}
func TestMerge(t *testing.T) {
for i, tt := range mergeTests {
got := tt.cfg.Copy()
got.MergeIn(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config %d %+v", i, tt.cfg)
t.Errorf(" Merge(%+v)", tt.in)
t.Errorf(" got %+v", got)
t.Errorf(" want %+v", tt.want)
}
}
}

71
vendor/github.com/aws/aws-sdk-go/aws/context.go generated vendored Normal file
View file

@ -0,0 +1,71 @@
package aws
import (
"time"
)
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext"
// API methods with Go v1.6 and a Context type such as golang.org/x/net/context.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
Value(key interface{}) interface{}
}
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

41
vendor/github.com/aws/aws-sdk-go/aws/context_1_6.go generated vendored Normal file
View file

@ -0,0 +1,41 @@
// +build !go1.7
package aws
import "time"
// An emptyCtx is a copy of the the Go 1.7 context.emptyCtx type. This
// is copied to provide a 1.6 and 1.5 safe version of context that is compatible
// with Go 1.7's Context.
//
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
// struct{}, since vars of this type must have distinct addresses.
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
func (e *emptyCtx) String() string {
switch e {
case backgroundCtx:
return "aws.BackgroundContext"
}
return "unknown empty Context"
}
var (
backgroundCtx = new(emptyCtx)
)

9
vendor/github.com/aws/aws-sdk-go/aws/context_1_7.go generated vendored Normal file
View file

@ -0,0 +1,9 @@
// +build go1.7
package aws
import "context"
var (
backgroundCtx = context.Background()
)

37
vendor/github.com/aws/aws-sdk-go/aws/context_test.go generated vendored Normal file
View file

@ -0,0 +1,37 @@
package aws_test
import (
"fmt"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestSleepWithContext(t *testing.T) {
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
err := aws.SleepWithContext(ctx, 1*time.Millisecond)
if err != nil {
t.Errorf("expect context to not be canceled, got %v", err)
}
}
func TestSleepWithContext_Canceled(t *testing.T) {
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
expectErr := fmt.Errorf("context canceled")
ctx.Error = expectErr
close(ctx.DoneCh)
err := aws.SleepWithContext(ctx, 1*time.Millisecond)
if err == nil {
t.Fatalf("expect error, did not get one")
}
if e, a := expectErr, err; e != a {
t.Errorf("expect %v error, got %v", e, a)
}
}

View file

@ -0,0 +1,437 @@
package aws
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var testCasesStringSlice = [][]string{
{"a", "b", "c", "d", "e"},
{"a", "b", "", "", "e"},
}
func TestStringSlice(t *testing.T) {
for idx, in := range testCasesStringSlice {
if in == nil {
continue
}
out := StringSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := StringValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesStringValueSlice = [][]*string{
{String("a"), String("b"), nil, String("c")},
}
func TestStringValueSlice(t *testing.T) {
for idx, in := range testCasesStringValueSlice {
if in == nil {
continue
}
out := StringValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := StringSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesStringMap = []map[string]string{
{"a": "1", "b": "2", "c": "3"},
}
func TestStringMap(t *testing.T) {
for idx, in := range testCasesStringMap {
if in == nil {
continue
}
out := StringMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := StringValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesBoolSlice = [][]bool{
{true, true, false, false},
}
func TestBoolSlice(t *testing.T) {
for idx, in := range testCasesBoolSlice {
if in == nil {
continue
}
out := BoolSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := BoolValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesBoolValueSlice = [][]*bool{}
func TestBoolValueSlice(t *testing.T) {
for idx, in := range testCasesBoolValueSlice {
if in == nil {
continue
}
out := BoolValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := BoolSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesBoolMap = []map[string]bool{
{"a": true, "b": false, "c": true},
}
func TestBoolMap(t *testing.T) {
for idx, in := range testCasesBoolMap {
if in == nil {
continue
}
out := BoolMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := BoolValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesIntSlice = [][]int{
{1, 2, 3, 4},
}
func TestIntSlice(t *testing.T) {
for idx, in := range testCasesIntSlice {
if in == nil {
continue
}
out := IntSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := IntValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesIntValueSlice = [][]*int{}
func TestIntValueSlice(t *testing.T) {
for idx, in := range testCasesIntValueSlice {
if in == nil {
continue
}
out := IntValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := IntSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesIntMap = []map[string]int{
{"a": 3, "b": 2, "c": 1},
}
func TestIntMap(t *testing.T) {
for idx, in := range testCasesIntMap {
if in == nil {
continue
}
out := IntMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := IntValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesInt64Slice = [][]int64{
{1, 2, 3, 4},
}
func TestInt64Slice(t *testing.T) {
for idx, in := range testCasesInt64Slice {
if in == nil {
continue
}
out := Int64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := Int64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesInt64ValueSlice = [][]*int64{}
func TestInt64ValueSlice(t *testing.T) {
for idx, in := range testCasesInt64ValueSlice {
if in == nil {
continue
}
out := Int64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := Int64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesInt64Map = []map[string]int64{
{"a": 3, "b": 2, "c": 1},
}
func TestInt64Map(t *testing.T) {
for idx, in := range testCasesInt64Map {
if in == nil {
continue
}
out := Int64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := Int64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesFloat64Slice = [][]float64{
{1, 2, 3, 4},
}
func TestFloat64Slice(t *testing.T) {
for idx, in := range testCasesFloat64Slice {
if in == nil {
continue
}
out := Float64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := Float64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesFloat64ValueSlice = [][]*float64{}
func TestFloat64ValueSlice(t *testing.T) {
for idx, in := range testCasesFloat64ValueSlice {
if in == nil {
continue
}
out := Float64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := Float64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesFloat64Map = []map[string]float64{
{"a": 3, "b": 2, "c": 1},
}
func TestFloat64Map(t *testing.T) {
for idx, in := range testCasesFloat64Map {
if in == nil {
continue
}
out := Float64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := Float64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesTimeSlice = [][]time.Time{
{time.Now(), time.Now().AddDate(100, 0, 0)},
}
func TestTimeSlice(t *testing.T) {
for idx, in := range testCasesTimeSlice {
if in == nil {
continue
}
out := TimeSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := TimeValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesTimeValueSlice = [][]*time.Time{}
func TestTimeValueSlice(t *testing.T) {
for idx, in := range testCasesTimeValueSlice {
if in == nil {
continue
}
out := TimeValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := TimeSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesTimeMap = []map[string]time.Time{
{"a": time.Now().AddDate(-100, 0, 0), "b": time.Now()},
}
func TestTimeMap(t *testing.T) {
for idx, in := range testCasesTimeMap {
if in == nil {
continue
}
out := TimeMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := TimeValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}

View file

@ -27,7 +27,7 @@ type lener interface {
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
//
// The Content-Length will only be aded to the request if the length of the body
// The Content-Length will only be added to the request if the length of the body
// is greater than 0. If the body is empty or the current `Content-Length`
// header is <= 0, the header will also be stripped.
var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) {
@ -71,8 +71,8 @@ var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// ValidateReqSigHandler is a request handler to ensure that the request's
// signature doesn't expire before it is sent. This can happen when a request
// is built and signed signficantly before it is sent. Or significant delays
// occur whne retrying requests that would cause the signature to expire.
// is built and signed significantly before it is sent. Or significant delays
// occur when retrying requests that would cause the signature to expire.
var ValidateReqSigHandler = request.NamedHandler{
Name: "core.ValidateReqSigHandler",
Fn: func(r *request.Request) {
@ -98,44 +98,79 @@ var ValidateReqSigHandler = request.NamedHandler{
}
// SendHandler is a request handler to send service request using HTTP client.
var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *request.Request) {
var err error
r.HTTPResponse, err = r.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
// Prevent leaking if an HTTPResponse was returned. Clean up
// the body.
if r.HTTPResponse != nil {
r.HTTPResponse.Body.Close()
var SendHandler = request.NamedHandler{
Name: "core.SendHandler",
Fn: func(r *request.Request) {
sender := sendFollowRedirects
if r.DisableFollowRedirects {
sender = sendWithoutFollowRedirects
}
// Capture the case where url.Error is returned for error processing
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other url redirect errors will
// comeback in a similar method.
if e, ok := err.(*url.Error); ok && e.Err != nil {
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(code),
Status: http.StatusText(int(code)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
return
}
var err error
r.HTTPResponse, err = sender(r)
if err != nil {
handleSendError(r, err)
}
if r.HTTPResponse == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
},
}
func sendFollowRedirects(r *request.Request) (*http.Response, error) {
return r.Config.HTTPClient.Do(r.HTTPRequest)
}
func sendWithoutFollowRedirects(r *request.Request) (*http.Response, error) {
transport := r.Config.HTTPClient.Transport
if transport == nil {
transport = http.DefaultTransport
}
return transport.RoundTrip(r.HTTPRequest)
}
func handleSendError(r *request.Request, err error) {
// Prevent leaking if an HTTPResponse was returned. Clean up
// the body.
if r.HTTPResponse != nil {
r.HTTPResponse.Body.Close()
}
// Capture the case where url.Error is returned for error processing
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other URL redirect errors will
// comeback in a similar method.
if e, ok := err.(*url.Error); ok && e.Err != nil {
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
StatusCode: int(code),
Status: http.StatusText(int(code)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
return
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = aws.Bool(true) // network errors are retryable
}
}}
if r.HTTPResponse == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = aws.Bool(true) // network errors are retryable
// Override the error with a context canceled error, if that was canceled.
ctx := r.Context()
select {
case <-ctx.Done():
r.Error = awserr.New(request.CanceledErrorCode,
"request context canceled", ctx.Err())
r.Retryable = aws.Bool(false)
default:
}
}
// ValidateResponseHandler is a request handler to validate service response.
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
@ -150,13 +185,22 @@ var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseH
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil {
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
r.Retryable = aws.Bool(r.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.RetryRules(r)
r.Config.SleepDelay(r.RetryDelay)
if sleepFn := r.Config.SleepDelay; sleepFn != nil {
// Support SleepDelay for backwards compatibility and testing
sleepFn(r.RetryDelay)
} else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil {
r.Error = awserr.New(request.CanceledErrorCode,
"request context canceled", err)
r.Retryable = aws.Bool(false)
return
}
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to

View file

@ -0,0 +1,355 @@
package corehandlers_test
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := awstesting.NewClient(aws.NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := awstesting.NewClient()
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, aws.ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retrieveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{ProviderName: "mockCredsProvider"}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewCredentials(credProvider),
MaxRetries: aws.Int(1),
})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *request.Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400, Body: ioutil.NopCloser(bytes.NewBuffer([]byte{}))}
})
svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
}
func TestAfterRetryWithContextCanceled(t *testing.T) {
c := awstesting.NewClient()
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
close(ctx.DoneCh)
ctx.Error = fmt.Errorf("context canceled")
corehandlers.AfterRetryHandler.Fn(req)
if req.Error == nil {
t.Fatalf("expect error but didn't receive one")
}
aerr := req.Error.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q, error code got %q", e, a)
}
}
func TestAfterRetryWithContext(t *testing.T) {
c := awstesting.NewClient()
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
corehandlers.AfterRetryHandler.Fn(req)
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
if e, a := 1, req.RetryCount; e != a {
t.Errorf("expect retry count to be %d, got %d", e, a)
}
}
func TestSendWithContextCanceled(t *testing.T) {
c := awstesting.NewClient(&aws.Config{
SleepDelay: func(dur time.Duration) {
t.Errorf("SleepDelay should not be called")
},
})
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
close(ctx.DoneCh)
ctx.Error = fmt.Errorf("context canceled")
corehandlers.SendHandler.Fn(req)
if req.Error == nil {
t.Fatalf("expect error but didn't receive one")
}
aerr := req.Error.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q, error code got %q", e, a)
}
}
type testSendHandlerTransport struct{}
func (t *testSendHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("mock error")
}
func TestSendHandlerError(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
HTTPClient: &http.Client{
Transport: &testSendHandlerTransport{},
},
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
r.Send()
assert.Error(t, r.Error)
assert.NotNil(t, r.HTTPResponse)
}
func TestSendWithoutFollowRedirects(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/original":
w.Header().Set("Location", "/redirected")
w.WriteHeader(301)
case "/redirected":
t.Fatalf("expect not to redirect, but was")
}
}))
svc := awstesting.NewClient(&aws.Config{
DisableSSL: aws.Bool(true),
Endpoint: aws.String(server.URL),
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{
Name: "Operation",
HTTPPath: "/original",
}, nil, nil)
r.DisableFollowRedirects = true
err := r.Send()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := 301, r.HTTPResponse.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}
func TestValidateReqSigHandler(t *testing.T) {
cases := []struct {
Req *request.Request
Resign bool
}{
{
Req: &request.Request{
Config: aws.Config{Credentials: credentials.AnonymousCredentials},
Time: time.Now().Add(-15 * time.Minute),
},
Resign: false,
},
{
Req: &request.Request{
Time: time.Now().Add(-15 * time.Minute),
},
Resign: true,
},
{
Req: &request.Request{
Time: time.Now().Add(-1 * time.Minute),
},
Resign: false,
},
}
for i, c := range cases {
resigned := false
c.Req.Handlers.Sign.PushBack(func(r *request.Request) {
resigned = true
})
corehandlers.ValidateReqSigHandler.Fn(c.Req)
assert.NoError(t, c.Req.Error, "%d, expect no error", i)
assert.Equal(t, c.Resign, resigned, "%d, expected resigning to match", i)
}
}
func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := r.Header["Content-Length"]
assert.Equal(t, hasContentLength, ok, "expect content length to be set, %t", hasContentLength)
if hasContentLength {
assert.Equal(t, contentLength, r.ContentLength)
}
b, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
r.Body.Close()
authHeader := r.Header.Get("Authorization")
if hasContentLength {
assert.Contains(t, authHeader, "content-length")
} else {
assert.NotContains(t, authHeader, "content-length")
}
assert.Equal(t, contentLength, int64(len(b)))
}))
return server
}
func TestBuildContentLength_ZeroBody(t *testing.T) {
server := setupContentLengthTestServer(t, false, 0)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
_, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
})
assert.NoError(t, err)
}
func TestBuildContentLength_NegativeBody(t *testing.T) {
server := setupContentLengthTestServer(t, false, 0)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
req, _ := svc.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
})
req.HTTPRequest.Header.Set("Content-Length", "-1")
assert.NoError(t, req.Send())
}
func TestBuildContentLength_WithBody(t *testing.T) {
server := setupContentLengthTestServer(t, true, 1024)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
_, err := svc.PutObject(&s3.PutObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
Body: bytes.NewReader(make([]byte, 1024)),
})
assert.NoError(t, err)
}

View file

@ -0,0 +1,254 @@
package corehandlers_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/stretchr/testify/require"
)
var testSvc = func() *client.Client {
s := &client.Client{
Config: aws.Config{},
ClientInfo: metadata.ClientInfo{
ServiceName: "mock-service",
APIVersion: "2015-01-01",
},
}
return s
}()
type StructShape struct {
_ struct{} `type:"structure"`
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
}
func (s *StructShape) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "StructShape"}
if s.RequiredList == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredList"))
}
if s.RequiredMap == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredMap"))
}
if s.RequiredBool == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredBool"))
}
if s.RequiredList != nil {
for i, v := range s.RequiredList {
if v == nil {
continue
}
if err := v.Validate(); err != nil {
invalidParams.AddNested(fmt.Sprintf("%s[%v]", "RequiredList", i), err.(request.ErrInvalidParams))
}
}
}
if s.RequiredMap != nil {
for i, v := range s.RequiredMap {
if v == nil {
continue
}
if err := v.Validate(); err != nil {
invalidParams.AddNested(fmt.Sprintf("%s[%v]", "RequiredMap", i), err.(request.ErrInvalidParams))
}
}
}
if s.OptionalStruct != nil {
if err := s.OptionalStruct.Validate(); err != nil {
invalidParams.AddNested("OptionalStruct", err.(request.ErrInvalidParams))
}
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
type ConditionalStructShape struct {
_ struct{} `type:"structure"`
Name *string `required:"true"`
}
func (s *ConditionalStructShape) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "ConditionalStructShape"}
if s.Name == nil {
invalidParams.Add(request.NewErrParamRequired("Name"))
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation error(s) found.", req.Error.(awserr.Error).Message())
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
assert.Len(t, errs, 3)
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredList.", errs[0].Error())
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredMap.", errs[1].Error())
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredBool.", errs[2].Error())
assert.Equal(t, "InvalidParameter: 3 validation error(s) found.\n- missing required field, StructShape.RequiredList.\n- missing required field, StructShape.RequiredMap.\n- missing required field, StructShape.RequiredBool.\n", req.Error.Error())
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{},
}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation error(s) found.", req.Error.(awserr.Error).Message())
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
assert.Len(t, errs, 3)
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredList[0].Name.", errs[0].Error())
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.RequiredMap[key2].Name.", errs[1].Error())
assert.Equal(t, "ParamRequiredError: missing required field, StructShape.OptionalStruct.Name.", errs[2].Error())
}
type testInput struct {
StringField *string `min:"5"`
ListField []string `min:"3"`
MapField map[string]string `min:"4"`
}
func (s testInput) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
if s.StringField != nil && len(*s.StringField) < 5 {
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
}
if s.ListField != nil && len(s.ListField) < 3 {
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
}
if s.MapField != nil && len(s.MapField) < 4 {
invalidParams.Add(request.NewErrParamMinLen("MapField", 4))
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
var testsFieldMin = []struct {
err awserr.Error
in testInput
}{
{
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd")},
},
{
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd"), ListField: []string{"a", "b"}},
},
{
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
invalidParams.Add(request.NewErrParamMinLen("MapField", 4))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd"), ListField: []string{"a", "b"}, MapField: map[string]string{"a": "a", "b": "b"}},
},
{
err: nil,
in: testInput{StringField: aws.String("abcde"),
ListField: []string{"a", "b", "c"}, MapField: map[string]string{"a": "a", "b": "b", "c": "c", "d": "d"}},
},
}
func TestValidateFieldMinParameter(t *testing.T) {
for i, c := range testsFieldMin {
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
corehandlers.ValidateParametersHandler.Fn(req)
assert.Equal(t, c.err, req.Error, "%d case failed", i)
}
}
func BenchmarkValidateAny(b *testing.B) {
input := &kinesis.PutRecordsInput{
StreamName: aws.String("stream"),
}
for i := 0; i < 100; i++ {
record := &kinesis.PutRecordsRequestEntry{
Data: make([]byte, 10000),
PartitionKey: aws.String("partition"),
}
input.Records = append(input.Records, record)
}
req, _ := kinesis.New(unit.Session).PutRecordsRequest(input)
b.ResetTimer()
for i := 0; i < b.N; i++ {
corehandlers.ValidateParametersHandler.Fn(req)
if err := req.Error; err != nil {
b.Fatalf("validation failed: %v", err)
}
}
}

View file

@ -13,7 +13,7 @@ var (
//
// @readonly
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
`no valid providers in chain. Deprecated.
`no valid providers in chain. Deprecated.
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,
nil)
)
@ -39,16 +39,18 @@ var (
// does not return any credentials ChainProvider will return the error
// ErrNoValidProvidersFoundInChain
//
// creds := NewChainCredentials(
// []Provider{
// &EnvProvider{},
// &EC2RoleProvider{
// creds := credentials.NewChainCredentials(
// []credentials.Provider{
// &credentials.EnvProvider{},
// &ec2rolecreds.EC2RoleProvider{
// Client: ec2metadata.New(sess),
// },
// })
//
// // Usage of ChainCredentials with aws.Config
// svc := ec2.New(&aws.Config{Credentials: creds})
// svc := ec2.New(session.Must(session.NewSession(&aws.Config{
// Credentials: creds,
// })))
//
type ChainProvider struct {
Providers []Provider

View file

@ -0,0 +1,154 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type secondStubProvider struct {
creds Value
expired bool
err error
}
func (s *secondStubProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "secondStubProvider"
return s.creds, s.err
}
func (s *secondStubProvider) IsExpired() bool {
return s.expired
}
func TestChainProviderWithNames(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&secondStubProvider{
creds: Value{
AccessKeyID: "AKIF",
SecretAccessKey: "NOSECRET",
SessionToken: "",
},
},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "secondStubProvider", creds.ProviderName, "Expect provider name to match")
// Also check credentials
assert.Equal(t, "AKIF", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "NOSECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderIsExpired(t *testing.T) {
stubProvider := &stubProvider{expired: true}
p := &ChainProvider{
Providers: []Provider{
stubProvider,
},
}
assert.True(t, p.IsExpired(), "Expect expired to be true before any Retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
stubProvider.expired = true
assert.True(t, p.IsExpired(), "Expect return of expired provider")
_, err = p.Retrieve()
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
}
func TestChainProviderWithNoProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
errs := []error{
awserr.New("FirstError", "first provider error", nil),
awserr.New("SecondError", "second provider error", nil),
}
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: errs[0]},
&stubProvider{err: errs[1]},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
}
func TestChainProviderWithNoValidProviderWithVerboseEnabled(t *testing.T) {
errs := []error{
awserr.New("FirstError", "first provider error", nil),
awserr.New("SecondError", "second provider error", nil),
}
p := &ChainProvider{
VerboseErrors: true,
Providers: []Provider{
&stubProvider{err: errs[0]},
&stubProvider{err: errs[1]},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t,
awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs),
err,
"Expect no providers error returned")
}

View file

@ -14,7 +14,7 @@
//
// Example of using the environment variable credentials.
//
// creds := NewEnvCredentials()
// creds := credentials.NewEnvCredentials()
//
// // Retrieve the credentials value
// credValue, err := creds.Get()
@ -26,7 +26,7 @@
// This may be helpful to proactively expire credentials and refresh them sooner
// than they would naturally expire on their own.
//
// creds := NewCredentials(&EC2RoleProvider{})
// creds := credentials.NewCredentials(&ec2rolecreds.EC2RoleProvider{})
// creds.Expire()
// credsValue, err := creds.Get()
// // New credentials will be retrieved instead of from cache.
@ -43,7 +43,7 @@
// func (m *MyProvider) Retrieve() (Value, error) {...}
// func (m *MyProvider) IsExpired() bool {...}
//
// creds := NewCredentials(&MyProvider{})
// creds := credentials.NewCredentials(&MyProvider{})
// credValue, err := creds.Get()
//
package credentials
@ -60,7 +60,9 @@ import (
// when making service API calls. For example, when accessing public
// s3 buckets.
//
// svc := s3.New(&aws.Config{Credentials: AnonymousCredentials})
// svc := s3.New(session.Must(session.NewSession(&aws.Config{
// Credentials: credentials.AnonymousCredentials,
// })))
// // Access public S3 buckets.
//
// @readonly
@ -88,7 +90,7 @@ type Value struct {
// The Provider should not need to implement its own mutexes, because
// that will be managed by Credentials.
type Provider interface {
// Refresh returns nil if it successfully retrieved the value.
// Retrieve returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
Retrieve() (Value, error)
@ -97,6 +99,27 @@ type Provider interface {
IsExpired() bool
}
// An ErrorProvider is a stub credentials provider that always returns an error
// this is used by the SDK when construction a known provider is not possible
// due to an error.
type ErrorProvider struct {
// The error to be returned from Retrieve
Err error
// The provider name to set on the Retrieved returned Value
ProviderName string
}
// Retrieve will always return the error that the ErrorProvider was created with.
func (p ErrorProvider) Retrieve() (Value, error) {
return Value{ProviderName: p.ProviderName}, p.Err
}
// IsExpired will always return not expired.
func (p ErrorProvider) IsExpired() bool {
return false
}
// A Expiry provides shared expiration logic to be used by credentials
// providers to implement expiry functionality.
//

View file

@ -0,0 +1,73 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type stubProvider struct {
creds Value
expired bool
err error
}
func (s *stubProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "stubProvider"
return s.creds, s.err
}
func (s *stubProvider) IsExpired() bool {
return s.expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.(awserr.Error).Code(), "Expected provider error")
}
func TestCredentialsExpire(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
stub.expired = false
assert.True(t, c.IsExpired(), "Expected to start out expired")
c.Expire()
assert.True(t, c.IsExpired(), "Expected to be expired")
c.forceRefresh = false
assert.False(t, c.IsExpired(), "Expected not to be expired")
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
}
func TestCredentialsGetWithProviderName(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, creds.ProviderName, "stubProvider", "Expected provider name to match")
}

View file

@ -0,0 +1,159 @@
package ec2rolecreds_test
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
const credsRespTmpl = `{
"Code": "Success",
"Type": "AWS-HMAC",
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s",
"LastUpdated" : "2009-11-23T0:00:00Z"
}`
const credsFailRespTmpl = `{
"Code": "ErrorCode",
"Message": "ErrorMsg",
"LastUpdated": "2009-11-23T0:00:00Z"
}`
func initTestServer(expireOn string, failAssume bool) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/latest/meta-data/iam/security-credentials" {
fmt.Fprintln(w, "RoleName")
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if failAssume {
fmt.Fprintf(w, credsFailRespTmpl)
} else {
fmt.Fprintf(w, credsRespTmpl, expireOn)
}
} else {
http.Error(w, "bad request", http.StatusBadRequest)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderFailAssume(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", true)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
assert.Error(t, err, "Expect error")
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "ErrorMsg", e.Message())
assert.Nil(t, e.OrigErr())
assert.Equal(t, "", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
ExpiryWindow: time.Hour * 1,
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC3RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := p.Retrieve(); err != nil {
b.Fatal(err)
}
}
}

View file

@ -0,0 +1,111 @@
package endpointcreds_test
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestRetrieveRefreshableCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/path/to/endpoint", r.URL.Path)
assert.Equal(t, "application/json", r.Header.Get("Accept"))
assert.Equal(t, "else", r.URL.Query().Get("something"))
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
"Token": "TOKEN",
"Expiration": time.Now().Add(1 * time.Hour),
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config,
unit.Session.Handlers,
server.URL+"/path/to/endpoint?something=else",
)
creds, err := client.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Equal(t, "TOKEN", creds.SessionToken)
assert.False(t, client.IsExpired())
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
assert.True(t, client.IsExpired())
}
func TestRetrieveStaticCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.False(t, client.IsExpired())
}
func TestFailedRetrieveCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(400)
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"Code": "Error",
"Message": "Message",
})
if err != nil {
fmt.Println("failed to write error", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
assert.Error(t, err)
aerr := err.(awserr.Error)
assert.Equal(t, "CredentialsEndpointError", aerr.Code())
assert.Equal(t, "failed to load credentials", aerr.Message())
aerr = aerr.OrigErr().(awserr.Error)
assert.Equal(t, "Error", aerr.Code())
assert.Equal(t, "Message", aerr.Message())
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.True(t, client.IsExpired())
}

View file

@ -29,6 +29,7 @@ var (
// Environment variables used:
//
// * Access Key ID: AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY
//
// * Secret Access Key: AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY
type EnvProvider struct {
retrieved bool

View file

@ -0,0 +1,70 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestEnvProviderRetrieve(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEnvProviderIsExpired(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
assert.True(t, e.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, e.IsExpired(), "Expect creds to not be expired after retrieve.")
}
func TestEnvProviderNoAccessKeyID(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrAccessKeyIDNotFound, err, "ErrAccessKeyIDNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderNoSecretAccessKey(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrSecretAccessKeyNotFound, err, "ErrSecretAccessKeyNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderAlternateNames(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY", "access")
os.Setenv("AWS_SECRET_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expected access key ID")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expected secret access key")
assert.Empty(t, creds.SessionToken, "Expected no token")
}

View file

@ -0,0 +1,116 @@
package credentials
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSharedCredentialsProvider(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderIsExpired(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "example.ini")
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILEAbsPath(t *testing.T) {
os.Clearenv()
wd, err := os.Getwd()
assert.NoError(t, err)
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(wd, "example.ini"))
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "no_token"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderColonInCredFile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "with_colon"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
}

View file

@ -0,0 +1,34 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestStaticProviderGet(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
creds, err := s.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no session token")
}
func TestStaticProviderIsExpired(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
assert.False(t, s.IsExpired(), "Expect static credentials to never expire")
}

View file

@ -1,7 +1,81 @@
// Package stscreds are credential Providers to retrieve STS AWS credentials.
//
// STS provides multiple ways to retrieve credentials which can be used when making
// future AWS service API operation calls.
/*
Package stscreds are credential Providers to retrieve STS AWS credentials.
STS provides multiple ways to retrieve credentials which can be used when making
future AWS service API operation calls.
The SDK will ensure that per instance of credentials.Credentials all requests
to refresh the credentials will be synchronized. But, the SDK is unable to
ensure synchronous usage of the AssumeRoleProvider if the value is shared
between multiple Credentials, Sessions or service clients.
Assume Role
To assume an IAM role using STS with the SDK you can create a new Credentials
with the SDKs's stscreds package.
// Initial credentials loaded from SDK's default credential chain. Such as
// the environment, shared credentials (~/.aws/credentials), or EC2 Instance
// Role. These credentials will be used to to make the STS Assume Role API.
sess := session.Must(session.NewSession())
// Create the credentials from AssumeRoleProvider to assume the role
// referenced by the "myRoleARN" ARN.
creds := stscreds.NewCredentials(sess, "myRoleArn")
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess, &aws.Config{Credentials: creds})
Assume Role with static MFA Token
To assume an IAM role with a MFA token you can either specify a MFA token code
directly or provide a function to prompt the user each time the credentials
need to refresh the role's credentials. Specifying the TokenCode should be used
for short lived operations that will not need to be refreshed, and when you do
not want to have direct control over the user provides their MFA token.
With TokenCode the AssumeRoleProvider will be not be able to refresh the role's
credentials.
// Create the credentials from AssumeRoleProvider to assume the role
// referenced by the "myRoleARN" ARN using the MFA token code provided.
creds := stscreds.NewCredentials(sess, "myRoleArn", func(p *stscreds.AssumeRoleProvider) {
p.SerialNumber = aws.String("myTokenSerialNumber")
p.TokenCode = aws.String("00000000")
})
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess, &aws.Config{Credentials: creds})
Assume Role with MFA Token Provider
To assume an IAM role with MFA for longer running tasks where the credentials
may need to be refreshed setting the TokenProvider field of AssumeRoleProvider
will allow the credential provider to prompt for new MFA token code when the
role's credentials need to be refreshed.
The StdinTokenProvider function is available to prompt on stdin to retrieve
the MFA token code from the user. You can also implement custom prompts by
satisfing the TokenProvider function signature.
Using StdinTokenProvider with multiple AssumeRoleProviders, or Credentials will
have undesirable results as the StdinTokenProvider will not be synchronized. A
single Credentials with an AssumeRoleProvider can be shared safely.
// Create the credentials from AssumeRoleProvider to assume the role
// referenced by the "myRoleARN" ARN. Prompting for MFA token from stdin.
creds := stscreds.NewCredentials(sess, "myRoleArn", func(p *stscreds.AssumeRoleProvider) {
p.SerialNumber = aws.String("myTokenSerialNumber")
p.TokenProvider = stscreds.StdinTokenProvider
})
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess, &aws.Config{Credentials: creds})
*/
package stscreds
import (
@ -9,11 +83,31 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
)
// StdinTokenProvider will prompt on stdout and read from stdin for a string value.
// An error is returned if reading from stdin fails.
//
// Use this function go read MFA tokens from stdin. The function makes no attempt
// to make atomic prompts from stdin across multiple gorouties.
//
// Using StdinTokenProvider with multiple AssumeRoleProviders, or Credentials will
// have undesirable results as the StdinTokenProvider will not be synchronized. A
// single Credentials with an AssumeRoleProvider can be shared safely
//
// Will wait forever until something is provided on the stdin.
func StdinTokenProvider() (string, error) {
var v string
fmt.Printf("Assume Role MFA token code: ")
_, err := fmt.Scanln(&v)
return v, err
}
// ProviderName provides a name of AssumeRole provider
const ProviderName = "AssumeRoleProvider"
@ -27,8 +121,15 @@ type AssumeRoler interface {
var DefaultDuration = time.Duration(15) * time.Minute
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
// keeps track of their expiration time.
//
// This credential provider will be used by the SDKs default credential change
// when shared configuration is enabled, and the shared config or shared credentials
// file configure assume role. See Session docs for how to do this.
//
// AssumeRoleProvider does not provide any synchronization and it is not safe
// to share this value across multiple Credentials, Sessions, or service clients
// without also sharing the same Credentials instance.
type AssumeRoleProvider struct {
credentials.Expiry
@ -65,8 +166,23 @@ type AssumeRoleProvider struct {
// assumed requires MFA (that is, if the policy includes a condition that tests
// for MFA). If the role being assumed requires MFA and if the TokenCode value
// is missing or expired, the AssumeRole call returns an "access denied" error.
//
// If SerialNumber is set and neither TokenCode nor TokenProvider are also
// set an error will be returned.
TokenCode *string
// Async method of providing MFA token code for assuming an IAM role with MFA.
// The value returned by the function will be used as the TokenCode in the Retrieve
// call. See StdinTokenProvider for a provider that prompts and reads from stdin.
//
// This token provider will be called when ever the assumed role's
// credentials need to be refreshed when SerialNumber is also set and
// TokenCode is not set.
//
// If both TokenCode and TokenProvider is set, TokenProvider will be used and
// TokenCode is ignored.
TokenProvider func() (string, error)
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
@ -85,6 +201,10 @@ type AssumeRoleProvider struct {
//
// Takes a Config provider to create the STS client. The ConfigProvider is
// satisfied by the session.Session type.
//
// It is safe to share the returned Credentials with multiple Sessions and
// service clients. All access to the credentials and refreshing them
// will be synchronized.
func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: sts.New(c),
@ -103,7 +223,11 @@ func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*As
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// Takes an AssumeRoler which can be satisfiede by the STS client.
// Takes an AssumeRoler which can be satisfied by the STS client.
//
// It is safe to share the returned Credentials with multiple Sessions and
// service clients. All access to the credentials and refreshing them
// will be synchronized.
func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: svc,
@ -139,12 +263,25 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
if p.Policy != nil {
input.Policy = p.Policy
}
if p.SerialNumber != nil && p.TokenCode != nil {
input.SerialNumber = p.SerialNumber
input.TokenCode = p.TokenCode
if p.SerialNumber != nil {
if p.TokenCode != nil {
input.SerialNumber = p.SerialNumber
input.TokenCode = p.TokenCode
} else if p.TokenProvider != nil {
input.SerialNumber = p.SerialNumber
code, err := p.TokenProvider()
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
input.TokenCode = aws.String(code)
} else {
return credentials.Value{ProviderName: ProviderName},
awserr.New("AssumeRoleTokenNotAvailable",
"assume role with MFA enabled, but neither TokenCode nor TokenProvider are set", nil)
}
}
roleOutput, err := p.Client.AssumeRole(input)
roleOutput, err := p.Client.AssumeRole(input)
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}

View file

@ -0,0 +1,150 @@
package stscreds
import (
"fmt"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
)
type stubSTS struct {
TestInput func(*sts.AssumeRoleInput)
}
func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
if s.TestInput != nil {
s.TestInput(input)
}
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
// Just reflect the role arn to the provider.
AccessKeyId: input.RoleArn,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,
},
}, nil
}
func TestAssumeRoleProvider(t *testing.T) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenCode(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenCode: aws.String("code"),
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenProvider: func() (string, error) {
return "code", nil
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenProvider: func() (string, error) {
return "", fmt.Errorf("error occurred")
},
}
creds, err := p.Retrieve()
assert.Error(t, err)
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
}
func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
}
creds, err := p.Retrieve()
assert.Error(t, err)
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
}
func BenchmarkAssumeRoleProvider(b *testing.B) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := p.Retrieve(); err != nil {
b.Fatal(err)
}
}
}

View file

@ -10,10 +10,12 @@ package defaults
import (
"fmt"
"net/http"
"net/url"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
@ -56,7 +58,6 @@ func Config() *aws.Config {
WithMaxRetries(aws.UseServiceDefaultRetries).
WithLogger(aws.NewDefaultLogger()).
WithLogLevel(aws.LogOff).
WithSleepDelay(time.Sleep).
WithEndpointResolver(endpoints.DefaultResolver())
}
@ -97,23 +98,51 @@ func CredChain(cfg *aws.Config, handlers request.Handlers) *credentials.Credenti
})
}
// RemoteCredProvider returns a credenitials provider for the default remote
const (
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
)
// RemoteCredProvider returns a credentials provider for the default remote
// endpoints such as EC2 or ECS Roles.
func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider {
ecsCredURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
if u := os.Getenv(httpProviderEnvVar); len(u) > 0 {
return localHTTPCredProvider(cfg, handlers, u)
}
if len(ecsCredURI) > 0 {
return ecsCredProvider(cfg, handlers, ecsCredURI)
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("http://169.254.170.2%s", uri)
return httpCredProvider(cfg, handlers, u)
}
return ec2RoleProvider(cfg, handlers)
}
func ecsCredProvider(cfg aws.Config, handlers request.Handlers, uri string) credentials.Provider {
const host = `169.254.170.2`
func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
var errMsg string
return endpointcreds.NewProviderClient(cfg, handlers,
fmt.Sprintf("http://%s%s", host, uri),
parsed, err := url.Parse(u)
if err != nil {
errMsg = fmt.Sprintf("invalid URL, %v", err)
} else if host := aws.URLHostname(parsed); !(host == "localhost" || host == "127.0.0.1") {
errMsg = fmt.Sprintf("invalid host address, %q, only localhost and 127.0.0.1 are valid.", host)
}
if len(errMsg) > 0 {
if cfg.Logger != nil {
cfg.Logger.Log("Ignoring, HTTP credential provider", errMsg, err)
}
return credentials.ErrorProvider{
Err: awserr.New("CredentialsEndpointError", errMsg, err),
ProviderName: endpointcreds.ProviderName,
}
}
return httpCredProvider(cfg, handlers, u)
}
func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
return endpointcreds.NewProviderClient(cfg, handlers, u,
func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute
},

View file

@ -0,0 +1,88 @@
package defaults
import (
"fmt"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestHTTPCredProvider(t *testing.T) {
cases := []struct {
Host string
Fail bool
}{
{"localhost", false}, {"127.0.0.1", false},
{"www.example.com", true}, {"169.254.170.2", true},
}
defer os.Clearenv()
for i, c := range cases {
u := fmt.Sprintf("http://%s/abc/123", c.Host)
os.Setenv(httpProviderEnvVar, u)
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("%d, expect provider not to be nil, but was", i)
}
if c.Fail {
creds, err := provider.Retrieve()
if err == nil {
t.Fatalf("%d, expect error but got none", i)
} else {
aerr := err.(awserr.Error)
if e, a := "CredentialsEndpointError", aerr.Code(); e != a {
t.Errorf("%d, expect %s error code, got %s", i, e, a)
}
}
if e, a := endpointcreds.ProviderName, creds.ProviderName; e != a {
t.Errorf("%d, expect %s provider name got %s", i, e, a)
}
} else {
httpProvider := provider.(*endpointcreds.Provider)
if e, a := u, httpProvider.Client.Endpoint; e != a {
t.Errorf("%d, expect %q endpoint, got %q", i, e, a)
}
}
}
}
func TestECSCredProvider(t *testing.T) {
defer os.Clearenv()
os.Setenv(ecsCredsProviderEnvVar, "/abc/123")
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
httpProvider := provider.(*endpointcreds.Provider)
if httpProvider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
if e, a := "http://169.254.170.2/abc/123", httpProvider.Client.Endpoint; e != a {
t.Errorf("expect %q endpoint, got %q", e, a)
}
}
func TestDefaultEC2RoleProvider(t *testing.T) {
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
ec2Provider := provider.(*ec2rolecreds.EC2RoleProvider)
if ec2Provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
if e, a := "http://169.254.169.254/latest", ec2Provider.Client.Endpoint; e != a {
t.Errorf("expect %q endpoint, got %q", e, a)
}
}

56
vendor/github.com/aws/aws-sdk-go/aws/doc.go generated vendored Normal file
View file

@ -0,0 +1,56 @@
// Package aws provides the core SDK's utilities and shared types. Use this package's
// utilities to simplify setting and reading API operations parameters.
//
// Value and Pointer Conversion Utilities
//
// This package includes a helper conversion utility for each scalar type the SDK's
// API use. These utilities make getting a pointer of the scalar, and dereferencing
// a pointer easier.
//
// Each conversion utility comes in two forms. Value to Pointer and Pointer to Value.
// The Pointer to value will safely dereference the pointer and return its value.
// If the pointer was nil, the scalar's zero value will be returned.
//
// The value to pointer functions will be named after the scalar type. So get a
// *string from a string value use the "String" function. This makes it easy to
// to get pointer of a literal string value, because getting the address of a
// literal requires assigning the value to a variable first.
//
// var strPtr *string
//
// // Without the SDK's conversion functions
// str := "my string"
// strPtr = &str
//
// // With the SDK's conversion functions
// strPtr = aws.String("my string")
//
// // Convert *string to string value
// str = aws.StringValue(strPtr)
//
// In addition to scalars the aws package also includes conversion utilities for
// map and slice for commonly types used in API parameters. The map and slice
// conversion functions use similar naming pattern as the scalar conversion
// functions.
//
// var strPtrs []*string
// var strs []string = []string{"Go", "Gophers", "Go"}
//
// // Convert []string to []*string
// strPtrs = aws.StringSlice(strs)
//
// // Convert []*string to []string
// strs = aws.StringValueSlice(strPtrs)
//
// SDK Default HTTP Client
//
// The SDK will use the http.DefaultClient if a HTTP client is not provided to
// the SDK's Session, or service client constructor. This means that if the
// http.DefaultClient is modified by other components of your application the
// modifications will be picked up by the SDK as well.
//
// In some cases this might be intended, but it is a better practice to create
// a custom HTTP Client to share explicitly through your application. You can
// configure the SDK to use the custom HTTP Client by setting the HTTPClient
// value of the SDK's Config type when creating a Session or service client.
package aws

View file

@ -0,0 +1,242 @@
package ec2metadata_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"path"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
const instanceIdentityDocument = `{
"devpayProductCodes" : null,
"availabilityZone" : "us-east-1d",
"privateIp" : "10.158.112.84",
"version" : "2010-08-31",
"region" : "us-east-1",
"instanceId" : "i-1234567890abcdef0",
"billingProducts" : null,
"instanceType" : "t1.micro",
"accountId" : "123456789012",
"pendingTime" : "2015-11-19T16:32:11Z",
"imageId" : "ami-5fb8c835",
"kernelId" : "aki-919dcaf8",
"ramdiskId" : null,
"architecture" : "x86_64"
}`
const validIamInfo = `{
"Code" : "Success",
"LastUpdated" : "2016-03-17T12:27:32Z",
"InstanceProfileArn" : "arn:aws:iam::123456789012:instance-profile/my-instance-profile",
"InstanceProfileId" : "AIPAABCDEFGHIJKLMN123"
}`
const unsuccessfulIamInfo = `{
"Code" : "Failed",
"LastUpdated" : "2016-03-17T12:27:32Z",
"InstanceProfileArn" : "arn:aws:iam::123456789012:instance-profile/my-instance-profile",
"InstanceProfileId" : "AIPAABCDEFGHIJKLMN123"
}`
func initTestServer(path string, resp string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI != path {
http.Error(w, "not found", http.StatusNotFound)
return
}
w.Write([]byte(resp))
}))
}
func TestEndpoint(t *testing.T) {
c := ec2metadata.New(unit.Session)
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", "testpath"),
}
req := c.NewRequest(op, nil, nil)
assert.Equal(t, "http://169.254.169.254/latest", req.ClientInfo.Endpoint)
assert.Equal(t, "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String())
}
func TestGetMetadata(t *testing.T) {
server := initTestServer(
"/latest/meta-data/some/path",
"success", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetMetadata("some/path")
assert.NoError(t, err)
assert.Equal(t, "success", resp)
}
func TestGetUserData(t *testing.T) {
server := initTestServer(
"/latest/user-data",
"success", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetUserData()
assert.NoError(t, err)
assert.Equal(t, "success", resp)
}
func TestGetUserData_Error(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reader := strings.NewReader(`<?xml version="1.0" encoding="iso-8859-1"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
<head>
<title>404 - Not Found</title>
</head>
<body>
<h1>404 - Not Found</h1>
</body>
</html>`)
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Length", fmt.Sprintf("%d", reader.Len()))
w.WriteHeader(http.StatusNotFound)
io.Copy(w, reader)
}))
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetUserData()
assert.Error(t, err)
assert.Empty(t, resp)
aerr, ok := err.(awserr.Error)
assert.True(t, ok)
assert.Equal(t, "NotFoundError", aerr.Code())
}
func TestGetRegion(t *testing.T) {
server := initTestServer(
"/latest/meta-data/placement/availability-zone",
"us-west-2a", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
region, err := c.Region()
assert.NoError(t, err)
assert.Equal(t, "us-west-2", region)
}
func TestMetadataAvailable(t *testing.T) {
server := initTestServer(
"/latest/meta-data/instance-id",
"instance-id",
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
available := c.Available()
assert.True(t, available)
}
func TestMetadataIAMInfo_success(t *testing.T) {
server := initTestServer(
"/latest/meta-data/iam/info",
validIamInfo,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
iamInfo, err := c.IAMInfo()
assert.NoError(t, err)
assert.Equal(t, "Success", iamInfo.Code)
assert.Equal(t, "arn:aws:iam::123456789012:instance-profile/my-instance-profile", iamInfo.InstanceProfileArn)
assert.Equal(t, "AIPAABCDEFGHIJKLMN123", iamInfo.InstanceProfileID)
}
func TestMetadataIAMInfo_failure(t *testing.T) {
server := initTestServer(
"/latest/meta-data/iam/info",
unsuccessfulIamInfo,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
iamInfo, err := c.IAMInfo()
assert.NotNil(t, err)
assert.Equal(t, "", iamInfo.Code)
assert.Equal(t, "", iamInfo.InstanceProfileArn)
assert.Equal(t, "", iamInfo.InstanceProfileID)
}
func TestMetadataNotAvailable(t *testing.T) {
c := ec2metadata.New(unit.Session)
c.Handlers.Send.Clear()
c.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
r.Error = awserr.New("RequestError", "send request failed", nil)
r.Retryable = aws.Bool(true) // network errors are retryable
})
available := c.Available()
assert.False(t, available)
}
func TestMetadataErrorResponse(t *testing.T) {
c := ec2metadata.New(unit.Session)
c.Handlers.Send.Clear()
c.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest),
Body: ioutil.NopCloser(strings.NewReader("error message text")),
}
r.Retryable = aws.Bool(false) // network errors are retryable
})
data, err := c.GetMetadata("uri/path")
assert.Empty(t, data)
assert.Contains(t, err.Error(), "error message text")
}
func TestEC2RoleProviderInstanceIdentity(t *testing.T) {
server := initTestServer(
"/latest/dynamic/instance-identity/document",
instanceIdentityDocument,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
doc, err := c.GetInstanceIdentityDocument()
assert.Nil(t, err, "Expect no error, %v", err)
assert.Equal(t, doc.AccountID, "123456789012")
assert.Equal(t, doc.AvailabilityZone, "us-east-1d")
assert.Equal(t, doc.Region, "us-east-1")
}

View file

@ -0,0 +1,78 @@
package ec2metadata_test
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session)
assert.NotEqual(t, http.DefaultClient, svc.Config.HTTPClient)
assert.Equal(t, 5*time.Second, svc.Config.HTTPClient.Timeout)
}
func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
http.DefaultClient.Transport = &http.Transport{}
defer func() {
http.DefaultClient.Transport = nil
}()
svc := ec2metadata.New(unit.Session)
assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, tr)
assert.Nil(t, tr.Dial)
}
func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true))
assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
}
func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("us-east-1a"))
}))
cfg := aws.NewConfig().WithEndpoint(server.URL)
runEC2MetadataClients(t, cfg, 100)
}
func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("us-east-1a"))
}))
cfg := aws.NewConfig().WithEndpoint(server.URL).WithHTTPClient(&http.Client{
Transport: http.DefaultTransport,
})
runEC2MetadataClients(t, cfg, 100)
}
func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
var wg sync.WaitGroup
wg.Add(atOnce)
for i := 0; i < atOnce; i++ {
go func() {
svc := ec2metadata.New(unit.Session, cfg)
_, err := svc.Region()
assert.NoError(t, err)
wg.Done()
}()
}
wg.Wait()
}

View file

@ -0,0 +1,117 @@
package endpoints
import (
"strings"
"testing"
)
func TestDecodeEndpoints_V3(t *testing.T) {
const v3Doc = `
{
"version": 3,
"partitions": [
{
"defaults": {
"hostname": "{service}.{region}.{dnsSuffix}",
"protocols": [
"https"
],
"signatureVersions": [
"v4"
]
},
"dnsSuffix": "amazonaws.com",
"partition": "aws",
"partitionName": "AWS Standard",
"regionRegex": "^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$",
"regions": {
"ap-northeast-1": {
"description": "Asia Pacific (Tokyo)"
}
},
"services": {
"acm": {
"endpoints": {
"ap-northeast-1": {}
}
},
"s3": {
"endpoints": {
"ap-northeast-1": {}
}
}
}
}
]
}`
resolver, err := DecodeModel(strings.NewReader(v3Doc))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
endpoint, err := resolver.EndpointFor("acm", "ap-northeast-1")
if err != nil {
t.Fatalf("failed to resolve endpoint, %v", err)
}
if a, e := endpoint.URL, "https://acm.ap-northeast-1.amazonaws.com"; a != e {
t.Errorf("expected %q URL got %q", e, a)
}
p := resolver.(partitions)[0]
s3Defaults := p.Services["s3"].Defaults
if a, e := s3Defaults.HasDualStack, boxedTrue; a != e {
t.Errorf("expect s3 service to have dualstack enabled")
}
if a, e := s3Defaults.DualStackHostname, "{service}.dualstack.{region}.{dnsSuffix}"; a != e {
t.Errorf("expect s3 dualstack host pattern to be %q, got %q", e, a)
}
ec2metaEndpoint := p.Services["ec2metadata"].Endpoints["aws-global"]
if a, e := ec2metaEndpoint.Hostname, "169.254.169.254/latest"; a != e {
t.Errorf("expect ec2metadata host to be %q, got %q", e, a)
}
}
func TestDecodeEndpoints_NoPartitions(t *testing.T) {
const doc = `{ "version": 3 }`
resolver, err := DecodeModel(strings.NewReader(doc))
if err == nil {
t.Fatalf("expected error")
}
if resolver != nil {
t.Errorf("expect resolver to be nil")
}
}
func TestDecodeEndpoints_UnsupportedVersion(t *testing.T) {
const doc = `{ "version": 2 }`
resolver, err := DecodeModel(strings.NewReader(doc))
if err == nil {
t.Fatalf("expected error decoding model")
}
if resolver != nil {
t.Errorf("expect resolver to be nil")
}
}
func TestDecodeModelOptionsSet(t *testing.T) {
var actual DecodeModelOptions
actual.Set(func(o *DecodeModelOptions) {
o.SkipCustomizations = true
})
expect := DecodeModelOptions{
SkipCustomizations: true,
}
if actual != expect {
t.Errorf("expect %v options got %v", expect, actual)
}
}

View file

@ -1,4 +1,4 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
// Code generated by aws/endpoints/v3model_codegen.go. DO NOT EDIT.
package endpoints
@ -60,6 +60,7 @@ const (
CodecommitServiceID = "codecommit" // Codecommit.
CodedeployServiceID = "codedeploy" // Codedeploy.
CodepipelineServiceID = "codepipeline" // Codepipeline.
CodestarServiceID = "codestar" // Codestar.
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
@ -83,6 +84,7 @@ const (
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
EmailServiceID = "email" // Email.
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
EsServiceID = "es" // Es.
EventsServiceID = "events" // Events.
FirehoseServiceID = "firehose" // Firehose.
@ -103,9 +105,12 @@ const (
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
ModelsLexServiceID = "models.lex" // ModelsLex.
MonitoringServiceID = "monitoring" // Monitoring.
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
OpsworksServiceID = "opsworks" // Opsworks.
OpsworksCmServiceID = "opsworks-cm" // OpsworksCm.
OrganizationsServiceID = "organizations" // Organizations.
PinpointServiceID = "pinpoint" // Pinpoint.
PollyServiceID = "polly" // Polly.
RdsServiceID = "rds" // Rds.
@ -129,8 +134,10 @@ const (
StsServiceID = "sts" // Sts.
SupportServiceID = "support" // Support.
SwfServiceID = "swf" // Swf.
TaggingServiceID = "tagging" // Tagging.
WafServiceID = "waf" // Waf.
WafRegionalServiceID = "waf-regional" // WafRegional.
WorkdocsServiceID = "workdocs" // Workdocs.
WorkspacesServiceID = "workspaces" // Workspaces.
XrayServiceID = "xray" // Xray.
)
@ -138,17 +145,20 @@ const (
// DefaultResolver returns an Endpoint resolver that will be able
// to resolve endpoints for: AWS Standard, AWS China, and AWS GovCloud (US).
//
// Casting the return value of this func to a EnumPartitions will
// allow you to get a list of the partitions in the order the endpoints
// will be resolved in.
// Use DefaultPartitions() to get the list of the default partitions.
func DefaultResolver() Resolver {
return defaultPartitions
}
// DefaultPartitions returns a list of the partitions the SDK is bundled
// with. The available partitions are: AWS Standard, AWS China, and AWS GovCloud (US).
//
// resolver := endpoints.DefaultResolver()
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
// partitions := endpoints.DefaultPartitions
// for _, p := range partitions {
// // ... inspect partitions
// }
func DefaultResolver() Resolver {
return defaultPartitions
func DefaultPartitions() []Partition {
return defaultPartitions.Partitions()
}
var defaultPartitions = partitions{
@ -246,6 +256,7 @@ var awsPartition = partition{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"ap-northeast-2": endpoint{},
"ap-south-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"eu-central-1": endpoint{},
@ -342,6 +353,7 @@ var awsPartition = partition{
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
@ -432,10 +444,14 @@ var awsPartition = partition{
"codebuild": service{
Endpoints: endpoints{
"eu-west-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
"ap-northeast-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
},
},
"codecommit": service{
@ -480,14 +496,25 @@ var awsPartition = partition{
"us-west-2": endpoint{},
},
},
"codestar": service{
Endpoints: endpoints{
"eu-west-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
},
},
"cognito-identity": service{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"ap-northeast-2": endpoint{},
"ap-south-1": endpoint{},
"ap-southeast-2": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
@ -498,9 +525,11 @@ var awsPartition = partition{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"ap-northeast-2": endpoint{},
"ap-south-1": endpoint{},
"ap-southeast-2": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
@ -511,9 +540,11 @@ var awsPartition = partition{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"ap-northeast-2": endpoint{},
"ap-south-1": endpoint{},
"ap-southeast-2": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
@ -749,10 +780,11 @@ var awsPartition = partition{
"elasticfilesystem": service{
Endpoints: endpoints{
"eu-west-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
"ap-southeast-2": endpoint{},
"eu-west-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
},
},
"elasticloadbalancing": service{
@ -823,6 +855,16 @@ var awsPartition = partition{
"us-west-2": endpoint{},
},
},
"entitlement.marketplace": service{
Defaults: endpoint{
CredentialScope: credentialScope{
Service: "aws-marketplace",
},
},
Endpoints: endpoints{
"us-east-1": endpoint{},
},
},
"es": service{
Endpoints: endpoints{
@ -831,8 +873,10 @@ var awsPartition = partition{
"ap-south-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"ca-central-1": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
"sa-east-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
@ -848,6 +892,7 @@ var awsPartition = partition{
"ap-south-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"ca-central-1": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
@ -958,6 +1003,7 @@ var awsPartition = partition{
"ap-southeast-2": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-2": endpoint{},
@ -1014,6 +1060,7 @@ var awsPartition = partition{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"ap-northeast-2": endpoint{},
"ap-south-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"eu-central-1": endpoint{},
@ -1075,10 +1122,13 @@ var awsPartition = partition{
"ap-south-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"ca-central-1": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
"sa-east-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-1": endpoint{},
"us-west-2": endpoint{},
},
@ -1089,6 +1139,16 @@ var awsPartition = partition{
"us-east-1": endpoint{},
},
},
"models.lex": service{
Defaults: endpoint{
CredentialScope: credentialScope{
Service: "lex",
},
},
Endpoints: endpoints{
"us-east-1": endpoint{},
},
},
"monitoring": service{
Defaults: endpoint{
Protocols: []string{"http", "https"},
@ -1110,6 +1170,16 @@ var awsPartition = partition{
"us-west-2": endpoint{},
},
},
"mturk-requester": service{
IsRegionalized: boxedFalse,
Endpoints: endpoints{
"sandbox": endpoint{
Hostname: "mturk-requester-sandbox.us-east-1.amazonaws.com",
},
"us-east-1": endpoint{},
},
},
"opsworks": service{
Endpoints: endpoints{
@ -1136,6 +1206,19 @@ var awsPartition = partition{
"us-west-2": endpoint{},
},
},
"organizations": service{
PartitionEndpoint: "aws-global",
IsRegionalized: boxedFalse,
Endpoints: endpoints{
"aws-global": endpoint{
Hostname: "organizations.us-east-1.amazonaws.com",
CredentialScope: credentialScope{
Region: "us-east-1",
},
},
},
},
"pinpoint": service{
Defaults: endpoint{
CredentialScope: credentialScope{
@ -1346,7 +1429,6 @@ var awsPartition = partition{
Endpoints: endpoints{
"ap-south-1": endpoint{},
"ap-southeast-2": endpoint{},
"ca-central-1": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
@ -1421,6 +1503,7 @@ var awsPartition = partition{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
@ -1432,6 +1515,7 @@ var awsPartition = partition{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"ap-northeast-2": endpoint{},
"ap-south-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"ca-central-1": endpoint{},
@ -1532,6 +1616,25 @@ var awsPartition = partition{
"us-west-2": endpoint{},
},
},
"tagging": service{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"ap-northeast-2": endpoint{},
"ap-south-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"ca-central-1": endpoint{},
"eu-central-1": endpoint{},
"eu-west-1": endpoint{},
"eu-west-2": endpoint{},
"sa-east-1": endpoint{},
"us-east-1": endpoint{},
"us-east-2": endpoint{},
"us-west-1": endpoint{},
"us-west-2": endpoint{},
},
},
"waf": service{
PartitionEndpoint: "aws-global",
IsRegionalized: boxedFalse,
@ -1554,6 +1657,17 @@ var awsPartition = partition{
"us-west-2": endpoint{},
},
},
"workdocs": service{
Endpoints: endpoints{
"ap-northeast-1": endpoint{},
"ap-southeast-1": endpoint{},
"ap-southeast-2": endpoint{},
"eu-west-1": endpoint{},
"us-east-1": endpoint{},
"us-west-2": endpoint{},
},
},
"workspaces": service{
Endpoints: endpoints{
@ -1632,6 +1746,12 @@ var awscnPartition = partition{
"cn-north-1": endpoint{},
},
},
"codedeploy": service{
Endpoints: endpoints{
"cn-north-1": endpoint{},
},
},
"config": service{
Endpoints: endpoints{
@ -1809,6 +1929,12 @@ var awscnPartition = partition{
},
"swf": service{
Endpoints: endpoints{
"cn-north-1": endpoint{},
},
},
"tagging": service{
Endpoints: endpoints{
"cn-north-1": endpoint{},
},
@ -1946,6 +2072,12 @@ var awsusgovPartition = partition{
},
},
},
"kinesis": service{
Endpoints: endpoints{
"us-gov-west-1": endpoint{},
},
},
"kms": service{
Endpoints: endpoints{

View file

@ -124,6 +124,49 @@ type EnumPartitions interface {
Partitions() []Partition
}
// RegionsForService returns a map of regions for the partition and service.
// If either the partition or service does not exist false will be returned
// as the second parameter.
//
// This example shows how to get the regions for DynamoDB in the AWS partition.
// rs, exists := endpoints.RegionsForService(endpoints.DefaultPartitions(), endpoints.AwsPartitionID, endpoints.DynamodbServiceID)
//
// This is equivalent to using the partition directly.
// rs := endpoints.AwsPartition().Services()[endpoints.DynamodbServiceID].Regions()
func RegionsForService(ps []Partition, partitionID, serviceID string) (map[string]Region, bool) {
for _, p := range ps {
if p.ID() != partitionID {
continue
}
if _, ok := p.p.Services[serviceID]; !ok {
break
}
s := Service{
id: serviceID,
p: p.p,
}
return s.Regions(), true
}
return map[string]Region{}, false
}
// PartitionForRegion returns the first partition which includes the region
// passed in. This includes both known regions and regions which match
// a pattern supported by the partition which may include regions that are
// not explicitly known by the partition. Use the Regions method of the
// returned Partition if explicit support is needed.
func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) {
for _, p := range ps {
if _, ok := p.p.Regions[regionID]; ok || p.p.RegionRegex.MatchString(regionID) {
return p, true
}
}
return Partition{}, false
}
// A Partition provides the ability to enumerate the partition's regions
// and services.
type Partition struct {
@ -132,7 +175,7 @@ type Partition struct {
}
// ID returns the identifier of the partition.
func (p *Partition) ID() string { return p.id }
func (p Partition) ID() string { return p.id }
// EndpointFor attempts to resolve the endpoint based on service and region.
// See Options for information on configuring how the endpoint is resolved.
@ -155,13 +198,13 @@ func (p *Partition) ID() string { return p.id }
// Errors that can be returned.
// * UnknownServiceError
// * UnknownEndpointError
func (p *Partition) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return p.p.EndpointFor(service, region, opts...)
}
// Regions returns a map of Regions indexed by their ID. This is useful for
// enumerating over the regions in a partition.
func (p *Partition) Regions() map[string]Region {
func (p Partition) Regions() map[string]Region {
rs := map[string]Region{}
for id := range p.p.Regions {
rs[id] = Region{
@ -175,7 +218,7 @@ func (p *Partition) Regions() map[string]Region {
// Services returns a map of Service indexed by their ID. This is useful for
// enumerating over the services in a partition.
func (p *Partition) Services() map[string]Service {
func (p Partition) Services() map[string]Service {
ss := map[string]Service{}
for id := range p.p.Services {
ss[id] = Service{
@ -195,16 +238,16 @@ type Region struct {
}
// ID returns the region's identifier.
func (r *Region) ID() string { return r.id }
func (r Region) ID() string { return r.id }
// ResolveEndpoint resolves an endpoint from the context of the region given
// a service. See Partition.EndpointFor for usage and errors that can be returned.
func (r *Region) ResolveEndpoint(service string, opts ...func(*Options)) (ResolvedEndpoint, error) {
func (r Region) ResolveEndpoint(service string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return r.p.EndpointFor(service, r.id, opts...)
}
// Services returns a list of all services that are known to be in this region.
func (r *Region) Services() map[string]Service {
func (r Region) Services() map[string]Service {
ss := map[string]Service{}
for id, s := range r.p.Services {
if _, ok := s.Endpoints[r.id]; ok {
@ -226,17 +269,38 @@ type Service struct {
}
// ID returns the identifier for the service.
func (s *Service) ID() string { return s.id }
func (s Service) ID() string { return s.id }
// ResolveEndpoint resolves an endpoint from the context of a service given
// a region. See Partition.EndpointFor for usage and errors that can be returned.
func (s *Service) ResolveEndpoint(region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
func (s Service) ResolveEndpoint(region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return s.p.EndpointFor(s.id, region, opts...)
}
// Regions returns a map of Regions that the service is present in.
//
// A region is the AWS region the service exists in. Whereas a Endpoint is
// an URL that can be resolved to a instance of a service.
func (s Service) Regions() map[string]Region {
rs := map[string]Region{}
for id := range s.p.Services[s.id].Endpoints {
if _, ok := s.p.Regions[id]; ok {
rs[id] = Region{
id: id,
p: s.p,
}
}
}
return rs
}
// Endpoints returns a map of Endpoints indexed by their ID for all known
// endpoints for a service.
func (s *Service) Endpoints() map[string]Endpoint {
//
// A region is the AWS region the service exists in. Whereas a Endpoint is
// an URL that can be resolved to a instance of a service.
func (s Service) Endpoints() map[string]Endpoint {
es := map[string]Endpoint{}
for id := range s.p.Services[s.id].Endpoints {
es[id] = Endpoint{
@ -259,15 +323,15 @@ type Endpoint struct {
}
// ID returns the identifier for an endpoint.
func (e *Endpoint) ID() string { return e.id }
func (e Endpoint) ID() string { return e.id }
// ServiceID returns the identifier the endpoint belongs to.
func (e *Endpoint) ServiceID() string { return e.serviceID }
func (e Endpoint) ServiceID() string { return e.serviceID }
// ResolveEndpoint resolves an endpoint from the context of a service and
// region the endpoint represents. See Partition.EndpointFor for usage and
// errors that can be returned.
func (e *Endpoint) ResolveEndpoint(opts ...func(*Options)) (ResolvedEndpoint, error) {
func (e Endpoint) ResolveEndpoint(opts ...func(*Options)) (ResolvedEndpoint, error) {
return e.p.EndpointFor(e.serviceID, e.id, opts...)
}
@ -300,28 +364,6 @@ type EndpointNotFoundError struct {
Region string
}
//// NewEndpointNotFoundError builds and returns NewEndpointNotFoundError.
//func NewEndpointNotFoundError(p, s, r string) EndpointNotFoundError {
// return EndpointNotFoundError{
// awsError: awserr.New("EndpointNotFoundError", "unable to find endpoint", nil),
// Partition: p,
// Service: s,
// Region: r,
// }
//}
//
//// Error returns string representation of the error.
//func (e EndpointNotFoundError) Error() string {
// extra := fmt.Sprintf("partition: %q, service: %q, region: %q",
// e.Partition, e.Service, e.Region)
// return awserr.SprintError(e.Code(), e.Message(), extra, e.OrigErr())
//}
//
//// String returns the string representation of the error.
//func (e EndpointNotFoundError) String() string {
// return e.Error()
//}
// A UnknownServiceError is returned when the service does not resolve to an
// endpoint. Includes a list of all known services for the partition. Returned
// when a partition does not support the service.

View file

@ -0,0 +1,335 @@
package endpoints
import "testing"
func TestEnumDefaultPartitions(t *testing.T) {
resolver := DefaultResolver()
enum, ok := resolver.(EnumPartitions)
if ok != true {
t.Fatalf("resolver must satisfy EnumPartition interface")
}
ps := enum.Partitions()
if a, e := len(ps), len(defaultPartitions); a != e {
t.Errorf("expected %d partitions, got %d", e, a)
}
}
func TestEnumDefaultRegions(t *testing.T) {
expectPart := defaultPartitions[0]
partEnum := defaultPartitions[0].Partition()
regEnum := partEnum.Regions()
if a, e := len(regEnum), len(expectPart.Regions); a != e {
t.Errorf("expected %d regions, got %d", e, a)
}
}
func TestEnumPartitionServices(t *testing.T) {
expectPart := testPartitions[0]
partEnum := testPartitions[0].Partition()
if a, e := partEnum.ID(), "part-id"; a != e {
t.Errorf("expect %q partition ID, got %q", e, a)
}
svcEnum := partEnum.Services()
if a, e := len(svcEnum), len(expectPart.Services); a != e {
t.Errorf("expected %d regions, got %d", e, a)
}
}
func TestEnumRegionServices(t *testing.T) {
p := testPartitions[0].Partition()
rs := p.Regions()
if a, e := len(rs), 2; a != e {
t.Errorf("expect %d regions got %d", e, a)
}
if _, ok := rs["us-east-1"]; !ok {
t.Errorf("expect us-east-1 region to be found, was not")
}
if _, ok := rs["us-west-2"]; !ok {
t.Errorf("expect us-west-2 region to be found, was not")
}
r := rs["us-east-1"]
if a, e := r.ID(), "us-east-1"; a != e {
t.Errorf("expect %q region ID, got %q", e, a)
}
ss := r.Services()
if a, e := len(ss), 1; a != e {
t.Errorf("expect %d services for us-east-1, got %d", e, a)
}
if _, ok := ss["service1"]; !ok {
t.Errorf("expect service1 service to be found, was not")
}
resolved, err := r.ResolveEndpoint("service1")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-east-1.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestEnumServiceRegions(t *testing.T) {
p := testPartitions[0].Partition()
rs := p.Services()["service1"].Regions()
if e, a := 2, len(rs); e != a {
t.Errorf("expect %d regions, got %d", e, a)
}
if _, ok := rs["us-east-1"]; !ok {
t.Errorf("expect region to be found")
}
if _, ok := rs["us-west-2"]; !ok {
t.Errorf("expect region to be found")
}
}
func TestEnumServicesEndpoints(t *testing.T) {
p := testPartitions[0].Partition()
ss := p.Services()
if a, e := len(ss), 5; a != e {
t.Errorf("expect %d regions got %d", e, a)
}
if _, ok := ss["service1"]; !ok {
t.Errorf("expect service1 region to be found, was not")
}
if _, ok := ss["service2"]; !ok {
t.Errorf("expect service2 region to be found, was not")
}
s := ss["service1"]
if a, e := s.ID(), "service1"; a != e {
t.Errorf("expect %q service ID, got %q", e, a)
}
resolved, err := s.ResolveEndpoint("us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-west-2.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestEnumEndpoints(t *testing.T) {
p := testPartitions[0].Partition()
s := p.Services()["service1"]
es := s.Endpoints()
if a, e := len(es), 2; a != e {
t.Errorf("expect %d endpoints for service2, got %d", e, a)
}
if _, ok := es["us-east-1"]; !ok {
t.Errorf("expect us-east-1 to be found, was not")
}
e := es["us-east-1"]
if a, e := e.ID(), "us-east-1"; a != e {
t.Errorf("expect %q endpoint ID, got %q", e, a)
}
if a, e := e.ServiceID(), "service1"; a != e {
t.Errorf("expect %q service ID, got %q", e, a)
}
resolved, err := e.ResolveEndpoint()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-east-1.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestResolveEndpointForPartition(t *testing.T) {
enum := testPartitions.Partitions()[0]
expected, err := testPartitions.EndpointFor("service1", "us-east-1")
actual, err := enum.EndpointFor("service1", "us-east-1")
if err != nil {
t.Fatalf("unexpected error, %v", err)
}
if expected != actual {
t.Errorf("expect resolved endpoint to be %v, but got %v", expected, actual)
}
}
func TestAddScheme(t *testing.T) {
cases := []struct {
In string
Expect string
DisableSSL bool
}{
{
In: "https://example.com",
Expect: "https://example.com",
},
{
In: "example.com",
Expect: "https://example.com",
},
{
In: "http://example.com",
Expect: "http://example.com",
},
{
In: "example.com",
Expect: "http://example.com",
DisableSSL: true,
},
{
In: "https://example.com",
Expect: "https://example.com",
DisableSSL: true,
},
}
for i, c := range cases {
actual := AddScheme(c.In, c.DisableSSL)
if actual != c.Expect {
t.Errorf("%d, expect URL to be %q, got %q", i, c.Expect, actual)
}
}
}
func TestResolverFunc(t *testing.T) {
var resolver Resolver
resolver = ResolverFunc(func(s, r string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return ResolvedEndpoint{
URL: "https://service.region.dnssuffix.com",
SigningRegion: "region",
SigningName: "service",
}, nil
})
resolved, err := resolver.EndpointFor("service", "region", func(o *Options) {
o.DisableSSL = true
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service.region.dnssuffix.com"; a != e {
t.Errorf("expect %q endpoint URL, got %q", e, a)
}
if a, e := resolved.SigningRegion, "region"; a != e {
t.Errorf("expect %q region, got %q", e, a)
}
if a, e := resolved.SigningName, "service"; a != e {
t.Errorf("expect %q signing name, got %q", e, a)
}
}
func TestOptionsSet(t *testing.T) {
var actual Options
actual.Set(DisableSSLOption, UseDualStackOption, StrictMatchingOption)
expect := Options{
DisableSSL: true,
UseDualStack: true,
StrictMatching: true,
}
if actual != expect {
t.Errorf("expect %v options got %v", expect, actual)
}
}
func TestRegionsForService(t *testing.T) {
ps := DefaultPartitions()
var expect map[string]Region
var serviceID string
for _, s := range ps[0].Services() {
expect = s.Regions()
serviceID = s.ID()
if len(expect) > 0 {
break
}
}
actual, ok := RegionsForService(ps, ps[0].ID(), serviceID)
if !ok {
t.Fatalf("expect regions to be found, was not")
}
if len(actual) == 0 {
t.Fatalf("expect service %s to have regions", serviceID)
}
if e, a := len(expect), len(actual); e != a {
t.Fatalf("expect %d regions, got %d", e, a)
}
for id, r := range actual {
if e, a := id, r.ID(); e != a {
t.Errorf("expect %s region id, got %s", e, a)
}
if _, ok := expect[id]; !ok {
t.Errorf("expect %s region to be found", id)
}
}
}
func TestRegionsForService_NotFound(t *testing.T) {
ps := testPartitions.Partitions()
actual, ok := RegionsForService(ps, ps[0].ID(), "service-not-exists")
if ok {
t.Fatalf("expect no regions to be found, but were")
}
if len(actual) != 0 {
t.Errorf("expect no regions, got %v", actual)
}
}
func TestPartitionForRegion(t *testing.T) {
ps := DefaultPartitions()
expect := ps[len(ps)%2]
var regionID string
for id := range expect.Regions() {
regionID = id
break
}
actual, ok := PartitionForRegion(ps, regionID)
if !ok {
t.Fatalf("expect partition to be found")
}
if e, a := expect.ID(), actual.ID(); e != a {
t.Errorf("expect %s partition, got %s", e, a)
}
}
func TestPartitionForRegion_NotFound(t *testing.T) {
ps := DefaultPartitions()
actual, ok := PartitionForRegion(ps, "regionNotExists")
if ok {
t.Errorf("expect no partition to be found, got %v", actual)
}
}

View file

@ -0,0 +1,66 @@
package endpoints_test
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sqs"
)
func ExampleEnumPartitions() {
resolver := endpoints.DefaultResolver()
partitions := resolver.(endpoints.EnumPartitions).Partitions()
for _, p := range partitions {
fmt.Println("Regions for", p.ID())
for id := range p.Regions() {
fmt.Println("*", id)
}
fmt.Println("Services for", p.ID())
for id := range p.Services() {
fmt.Println("*", id)
}
}
}
func ExampleResolverFunc() {
myCustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if service == endpoints.S3ServiceID {
return endpoints.ResolvedEndpoint{
URL: "s3.custom.endpoint.com",
SigningRegion: "custom-signing-region",
}, nil
}
return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
}
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String("us-west-2"),
EndpointResolver: endpoints.ResolverFunc(myCustomResolver),
}))
// Create the S3 service client with the shared session. This will
// automatically use the S3 custom endpoint configured in the custom
// endpoint resolver wrapping the default endpoint resolver.
s3Svc := s3.New(sess)
// Operation calls will be made to the custom endpoint.
s3Svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String("myBucket"),
Key: aws.String("myObjectKey"),
})
// Create the SQS service client with the shared session. This will
// fallback to the default endpoint resolver because the customization
// passes any non S3 service endpoint resolve to the default resolver.
sqsSvc := sqs.New(sess)
// Operation calls will be made to the default endpoint for SQS for the
// region configured.
sqsSvc.ReceiveMessage(&sqs.ReceiveMessageInput{
QueueUrl: aws.String("my-queue-url"),
})
}

View file

@ -158,7 +158,7 @@ var funcMap = template.FuncMap{
const v3Tmpl = `
{{ define "defaults" -}}
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
// Code generated by aws/endpoints/v3model_codegen.go. DO NOT EDIT.
package endpoints
@ -209,17 +209,20 @@ import (
// DefaultResolver returns an Endpoint resolver that will be able
// to resolve endpoints for: {{ ListPartitionNames . }}.
//
// Casting the return value of this func to a EnumPartitions will
// allow you to get a list of the partitions in the order the endpoints
// will be resolved in.
// Use DefaultPartitions() to get the list of the default partitions.
func DefaultResolver() Resolver {
return defaultPartitions
}
// DefaultPartitions returns a list of the partitions the SDK is bundled
// with. The available partitions are: {{ ListPartitionNames . }}.
//
// resolver := endpoints.DefaultResolver()
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
// partitions := endpoints.DefaultPartitions
// for _, p := range partitions {
// // ... inspect partitions
// }
func DefaultResolver() Resolver {
return defaultPartitions
func DefaultPartitions() []Partition {
return defaultPartitions.Partitions()
}
var defaultPartitions = partitions{

View file

@ -0,0 +1,354 @@
package endpoints
import (
"encoding/json"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUnmarshalRegionRegex(t *testing.T) {
var input = []byte(`
{
"regionRegex": "^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$"
}`)
p := partition{}
err := json.Unmarshal(input, &p)
assert.NoError(t, err)
expectRegexp, err := regexp.Compile(`^(us|eu|ap|sa|ca)\-\w+\-\d+$`)
assert.NoError(t, err)
assert.Equal(t, expectRegexp.String(), p.RegionRegex.Regexp.String())
}
func TestUnmarshalRegion(t *testing.T) {
var input = []byte(`
{
"aws-global": {
"description": "AWS partition-global endpoint"
},
"us-east-1": {
"description": "US East (N. Virginia)"
}
}`)
rs := regions{}
err := json.Unmarshal(input, &rs)
assert.NoError(t, err)
assert.Len(t, rs, 2)
r, ok := rs["aws-global"]
assert.True(t, ok)
assert.Equal(t, "AWS partition-global endpoint", r.Description)
r, ok = rs["us-east-1"]
assert.True(t, ok)
assert.Equal(t, "US East (N. Virginia)", r.Description)
}
func TestUnmarshalServices(t *testing.T) {
var input = []byte(`
{
"acm": {
"endpoints": {
"us-east-1": {}
}
},
"apigateway": {
"isRegionalized": true,
"endpoints": {
"us-east-1": {},
"us-west-2": {}
}
},
"notRegionalized": {
"isRegionalized": false,
"endpoints": {
"us-east-1": {},
"us-west-2": {}
}
}
}`)
ss := services{}
err := json.Unmarshal(input, &ss)
assert.NoError(t, err)
assert.Len(t, ss, 3)
s, ok := ss["acm"]
assert.True(t, ok)
assert.Len(t, s.Endpoints, 1)
assert.Equal(t, boxedBoolUnset, s.IsRegionalized)
s, ok = ss["apigateway"]
assert.True(t, ok)
assert.Len(t, s.Endpoints, 2)
assert.Equal(t, boxedTrue, s.IsRegionalized)
s, ok = ss["notRegionalized"]
assert.True(t, ok)
assert.Len(t, s.Endpoints, 2)
assert.Equal(t, boxedFalse, s.IsRegionalized)
}
func TestUnmarshalEndpoints(t *testing.T) {
var inputs = []byte(`
{
"aws-global": {
"hostname": "cloudfront.amazonaws.com",
"protocols": [
"http",
"https"
],
"signatureVersions": [ "v4" ],
"credentialScope": {
"region": "us-east-1",
"service": "serviceName"
},
"sslCommonName": "commonName"
},
"us-east-1": {}
}`)
es := endpoints{}
err := json.Unmarshal(inputs, &es)
assert.NoError(t, err)
assert.Len(t, es, 2)
s, ok := es["aws-global"]
assert.True(t, ok)
assert.Equal(t, "cloudfront.amazonaws.com", s.Hostname)
assert.Equal(t, []string{"http", "https"}, s.Protocols)
assert.Equal(t, []string{"v4"}, s.SignatureVersions)
assert.Equal(t, credentialScope{"us-east-1", "serviceName"}, s.CredentialScope)
assert.Equal(t, "commonName", s.SSLCommonName)
}
func TestEndpointResolve(t *testing.T) {
defs := []endpoint{
{
Hostname: "{service}.{region}.{dnsSuffix}",
SignatureVersions: []string{"v2"},
SSLCommonName: "sslCommonName",
},
{
Hostname: "other-hostname",
Protocols: []string{"http"},
CredentialScope: credentialScope{
Region: "signing_region",
Service: "signing_service",
},
},
}
e := endpoint{
Hostname: "{service}.{region}.{dnsSuffix}",
Protocols: []string{"http", "https"},
SignatureVersions: []string{"v4"},
SSLCommonName: "new sslCommonName",
}
resolved := e.resolve("service", "region", "dnsSuffix",
defs, Options{},
)
assert.Equal(t, "https://service.region.dnsSuffix", resolved.URL)
assert.Equal(t, "signing_service", resolved.SigningName)
assert.Equal(t, "signing_region", resolved.SigningRegion)
assert.Equal(t, "v4", resolved.SigningMethod)
}
func TestEndpointMergeIn(t *testing.T) {
expected := endpoint{
Hostname: "other hostname",
Protocols: []string{"http"},
SignatureVersions: []string{"v4"},
SSLCommonName: "ssl common name",
CredentialScope: credentialScope{
Region: "region",
Service: "service",
},
}
actual := endpoint{}
actual.mergeIn(endpoint{
Hostname: "other hostname",
Protocols: []string{"http"},
SignatureVersions: []string{"v4"},
SSLCommonName: "ssl common name",
CredentialScope: credentialScope{
Region: "region",
Service: "service",
},
})
assert.Equal(t, expected, actual)
}
var testPartitions = partitions{
partition{
ID: "part-id",
Name: "partitionName",
DNSSuffix: "amazonaws.com",
RegionRegex: regionRegex{
Regexp: func() *regexp.Regexp {
reg, _ := regexp.Compile("^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$")
return reg
}(),
},
Defaults: endpoint{
Hostname: "{service}.{region}.{dnsSuffix}",
Protocols: []string{"https"},
SignatureVersions: []string{"v4"},
},
Regions: regions{
"us-east-1": region{
Description: "region description",
},
"us-west-2": region{},
},
Services: services{
"s3": service{},
"service1": service{
Endpoints: endpoints{
"us-east-1": {},
"us-west-2": {
HasDualStack: boxedTrue,
DualStackHostname: "{service}.dualstack.{region}.{dnsSuffix}",
},
},
},
"service2": service{},
"httpService": service{
Defaults: endpoint{
Protocols: []string{"http"},
},
},
"globalService": service{
IsRegionalized: boxedFalse,
PartitionEndpoint: "aws-global",
Endpoints: endpoints{
"aws-global": endpoint{
CredentialScope: credentialScope{
Region: "us-east-1",
},
Hostname: "globalService.amazonaws.com",
},
},
},
},
},
}
func TestResolveEndpoint(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-west-2")
assert.NoError(t, err)
assert.Equal(t, "https://service2.us-west-2.amazonaws.com", resolved.URL)
assert.Equal(t, "us-west-2", resolved.SigningRegion)
assert.Equal(t, "service2", resolved.SigningName)
}
func TestResolveEndpoint_DisableSSL(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-west-2", DisableSSLOption)
assert.NoError(t, err)
assert.Equal(t, "http://service2.us-west-2.amazonaws.com", resolved.URL)
assert.Equal(t, "us-west-2", resolved.SigningRegion)
assert.Equal(t, "service2", resolved.SigningName)
}
func TestResolveEndpoint_UseDualStack(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service1", "us-west-2", UseDualStackOption)
assert.NoError(t, err)
assert.Equal(t, "https://service1.dualstack.us-west-2.amazonaws.com", resolved.URL)
assert.Equal(t, "us-west-2", resolved.SigningRegion)
assert.Equal(t, "service1", resolved.SigningName)
}
func TestResolveEndpoint_HTTPProtocol(t *testing.T) {
resolved, err := testPartitions.EndpointFor("httpService", "us-west-2")
assert.NoError(t, err)
assert.Equal(t, "http://httpService.us-west-2.amazonaws.com", resolved.URL)
assert.Equal(t, "us-west-2", resolved.SigningRegion)
assert.Equal(t, "httpService", resolved.SigningName)
}
func TestResolveEndpoint_UnknownService(t *testing.T) {
_, err := testPartitions.EndpointFor("unknownservice", "us-west-2")
assert.Error(t, err)
_, ok := err.(UnknownServiceError)
assert.True(t, ok, "expect error to be UnknownServiceError")
}
func TestResolveEndpoint_ResolveUnknownService(t *testing.T) {
resolved, err := testPartitions.EndpointFor("unknown-service", "us-region-1",
ResolveUnknownServiceOption)
assert.NoError(t, err)
assert.Equal(t, "https://unknown-service.us-region-1.amazonaws.com", resolved.URL)
assert.Equal(t, "us-region-1", resolved.SigningRegion)
assert.Equal(t, "unknown-service", resolved.SigningName)
}
func TestResolveEndpoint_UnknownMatchedRegion(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-region-1")
assert.NoError(t, err)
assert.Equal(t, "https://service2.us-region-1.amazonaws.com", resolved.URL)
assert.Equal(t, "us-region-1", resolved.SigningRegion)
assert.Equal(t, "service2", resolved.SigningName)
}
func TestResolveEndpoint_UnknownRegion(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "unknownregion")
assert.NoError(t, err)
assert.Equal(t, "https://service2.unknownregion.amazonaws.com", resolved.URL)
assert.Equal(t, "unknownregion", resolved.SigningRegion)
assert.Equal(t, "service2", resolved.SigningName)
}
func TestResolveEndpoint_StrictPartitionUnknownEndpoint(t *testing.T) {
_, err := testPartitions[0].EndpointFor("service2", "unknownregion", StrictMatchingOption)
assert.Error(t, err)
_, ok := err.(UnknownEndpointError)
assert.True(t, ok, "expect error to be UnknownEndpointError")
}
func TestResolveEndpoint_StrictPartitionsUnknownEndpoint(t *testing.T) {
_, err := testPartitions.EndpointFor("service2", "us-region-1", StrictMatchingOption)
assert.Error(t, err)
_, ok := err.(UnknownEndpointError)
assert.True(t, ok, "expect error to be UnknownEndpointError")
}
func TestResolveEndpoint_NotRegionalized(t *testing.T) {
resolved, err := testPartitions.EndpointFor("globalService", "us-west-2")
assert.NoError(t, err)
assert.Equal(t, "https://globalService.amazonaws.com", resolved.URL)
assert.Equal(t, "us-east-1", resolved.SigningRegion)
assert.Equal(t, "globalService", resolved.SigningName)
}
func TestResolveEndpoint_AwsGlobal(t *testing.T) {
resolved, err := testPartitions.EndpointFor("globalService", "aws-global")
assert.NoError(t, err)
assert.Equal(t, "https://globalService.amazonaws.com", resolved.URL)
assert.Equal(t, "us-east-1", resolved.SigningRegion)
assert.Equal(t, "globalService", resolved.SigningName)
}

12
vendor/github.com/aws/aws-sdk-go/aws/jsonvalue.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
package aws
// JSONValue is a representation of a grab bag type that will be marshaled
// into a json string. This type can be used just like any other map.
//
// Example:
//
// values := aws.JSONValue{
// "Foo": "Bar",
// }
// values["Baz"] = "Qux"
type JSONValue map[string]interface{}

View file

@ -0,0 +1,19 @@
// +build !appengine
package request
import (
"net"
"os"
"syscall"
)
func isErrConnectionReset(err error) bool {
if opErr, ok := err.(*net.OpError); ok {
if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
return sysErr.Err == syscall.ECONNRESET
}
}
return false
}

View file

@ -0,0 +1,11 @@
// +build appengine
package request
import (
"strings"
)
func isErrConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset")
}

View file

@ -0,0 +1,9 @@
// +build appengine
package request_test
import (
"errors"
)
var stubConnectionResetError = errors.New("connection reset")

View file

@ -0,0 +1,11 @@
// +build !appengine
package request_test
import (
"net"
"os"
"syscall"
)
var stubConnectionResetError = &net.OpError{Err: &os.SyscallError{Syscall: "read", Err: syscall.ECONNRESET}}

View file

@ -18,6 +18,7 @@ type Handlers struct {
UnmarshalError HandlerList
Retry HandlerList
AfterRetry HandlerList
Complete HandlerList
}
// Copy returns of this handler's lists.
@ -33,6 +34,7 @@ func (h *Handlers) Copy() Handlers {
UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(),
Complete: h.Complete.copy(),
}
}
@ -48,6 +50,7 @@ func (h *Handlers) Clear() {
h.ValidateResponse.Clear()
h.Retry.Clear()
h.AfterRetry.Clear()
h.Complete.Clear()
}
// A HandlerListRunItem represents an entry in the HandlerList which
@ -85,13 +88,17 @@ func (l *HandlerList) copy() HandlerList {
n := HandlerList{
AfterEachFn: l.AfterEachFn,
}
n.list = append([]NamedHandler{}, l.list...)
if len(l.list) == 0 {
return n
}
n.list = append(make([]NamedHandler, 0, len(l.list)), l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []NamedHandler{}
l.list = l.list[0:0]
}
// Len returns the number of handlers in the list.
@ -101,33 +108,54 @@ func (l *HandlerList) Len() int {
// PushBack pushes handler f to the back of the handler list.
func (l *HandlerList) PushBack(f func(*Request)) {
l.list = append(l.list, NamedHandler{"__anonymous", f})
}
// PushFront pushes handler f to the front of the handler list.
func (l *HandlerList) PushFront(f func(*Request)) {
l.list = append([]NamedHandler{{"__anonymous", f}}, l.list...)
l.PushBackNamed(NamedHandler{"__anonymous", f})
}
// PushBackNamed pushes named handler f to the back of the handler list.
func (l *HandlerList) PushBackNamed(n NamedHandler) {
if cap(l.list) == 0 {
l.list = make([]NamedHandler, 0, 5)
}
l.list = append(l.list, n)
}
// PushFront pushes handler f to the front of the handler list.
func (l *HandlerList) PushFront(f func(*Request)) {
l.PushFrontNamed(NamedHandler{"__anonymous", f})
}
// PushFrontNamed pushes named handler f to the front of the handler list.
func (l *HandlerList) PushFrontNamed(n NamedHandler) {
l.list = append([]NamedHandler{n}, l.list...)
if cap(l.list) == len(l.list) {
// Allocating new list required
l.list = append([]NamedHandler{n}, l.list...)
} else {
// Enough room to prepend into list.
l.list = append(l.list, NamedHandler{})
copy(l.list[1:], l.list)
l.list[0] = n
}
}
// Remove removes a NamedHandler n
func (l *HandlerList) Remove(n NamedHandler) {
newlist := []NamedHandler{}
for _, m := range l.list {
if m.Name != n.Name {
newlist = append(newlist, m)
l.RemoveByName(n.Name)
}
// RemoveByName removes a NamedHandler by name.
func (l *HandlerList) RemoveByName(name string) {
for i := 0; i < len(l.list); i++ {
m := l.list[i]
if m.Name == name {
// Shift array preventing creating new arrays
copy(l.list[i:], l.list[i+1:])
l.list[len(l.list)-1] = NamedHandler{}
l.list = l.list[:len(l.list)-1]
// decrement list so next check to length is correct
i--
}
}
l.list = newlist
}
// Run executes all handlers in the list with a given request object.
@ -163,6 +191,16 @@ func HandlerListStopOnError(item HandlerListRunItem) bool {
return item.Request.Error == nil
}
// WithAppendUserAgent will add a string to the user agent prefixed with a
// single white space.
func WithAppendUserAgent(s string) Option {
return func(r *Request) {
r.Handlers.Build.PushBack(func(r2 *Request) {
AddToUserAgent(r, s)
})
}
}
// MakeAddToUserAgentHandler will add the name/version pair to the User-Agent request
// header. If the extra parameters are provided they will be added as metadata to the
// name/version pair resulting in the following format.

View file

@ -0,0 +1,157 @@
package request_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &request.Request{}
l := request.HandlerList{}
l.PushBack(func(r *request.Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &request.Request{}
l := request.HandlerList{}
l.PushBack(func(r *request.Request) { r.Data = nil })
l.PushFront(func(r *request.Request) { r.Data = aws.Bool(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}
func TestNamedHandlers(t *testing.T) {
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {}}
named2 := request.NamedHandler{Name: "NotName", Fn: func(r *request.Request) {}}
l.PushBackNamed(named)
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.PushBack(func(r *request.Request) {})
assert.Equal(t, 4, l.Len())
l.Remove(named)
assert.Equal(t, 2, l.Len())
}
func TestLoggedHandlers(t *testing.T) {
expectedHandlers := []string{"name1", "name2"}
l := request.HandlerList{}
loggedHandlers := []string{}
l.AfterEachFn = request.HandlerListLogItem
cfg := aws.Config{Logger: aws.LoggerFunc(func(args ...interface{}) {
loggedHandlers = append(loggedHandlers, args[2].(string))
})}
named1 := request.NamedHandler{Name: "name1", Fn: func(r *request.Request) {}}
named2 := request.NamedHandler{Name: "name2", Fn: func(r *request.Request) {}}
l.PushBackNamed(named1)
l.PushBackNamed(named2)
l.Run(&request.Request{Config: cfg})
assert.Equal(t, expectedHandlers, loggedHandlers)
}
func TestStopHandlers(t *testing.T) {
l := request.HandlerList{}
stopAt := 1
l.AfterEachFn = func(item request.HandlerListRunItem) bool {
return item.Index != stopAt
}
called := 0
l.PushBackNamed(request.NamedHandler{Name: "name1", Fn: func(r *request.Request) {
called++
}})
l.PushBackNamed(request.NamedHandler{Name: "name2", Fn: func(r *request.Request) {
called++
}})
l.PushBackNamed(request.NamedHandler{Name: "name3", Fn: func(r *request.Request) {
assert.Fail(t, "third handler should not be called")
}})
l.Run(&request.Request{})
assert.Equal(t, 2, called, "Expect only two handlers to be called")
}
func BenchmarkNewRequest(b *testing.B) {
svc := s3.New(unit.Session)
for i := 0; i < b.N; i++ {
r, _ := svc.GetObjectRequest(nil)
if r == nil {
b.Fatal("r should not be nil")
}
}
}
func BenchmarkHandlersCopy(b *testing.B) {
handlers := request.Handlers{}
handlers.Validate.PushBack(func(r *request.Request) {})
handlers.Validate.PushBack(func(r *request.Request) {})
handlers.Build.PushBack(func(r *request.Request) {})
handlers.Build.PushBack(func(r *request.Request) {})
handlers.Send.PushBack(func(r *request.Request) {})
handlers.Send.PushBack(func(r *request.Request) {})
handlers.Unmarshal.PushBack(func(r *request.Request) {})
handlers.Unmarshal.PushBack(func(r *request.Request) {})
for i := 0; i < b.N; i++ {
h := handlers.Copy()
if e, a := handlers.Validate.Len(), h.Validate.Len(); e != a {
b.Fatalf("expected %d handlers got %d", e, a)
}
}
}
func BenchmarkHandlersPushBack(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
}
}
func BenchmarkHandlersPushFront(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
}
}
func BenchmarkHandlersClear(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Clear()
}
}

View file

@ -0,0 +1,34 @@
package request
import (
"bytes"
"io/ioutil"
"net/http"
"net/url"
"sync"
"testing"
)
func TestRequestCopyRace(t *testing.T) {
origReq := &http.Request{URL: &url.URL{}, Header: http.Header{}}
origReq.Header.Set("Header", "OrigValue")
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
req := copyHTTPRequest(origReq, ioutil.NopCloser(&bytes.Buffer{}))
req.Header.Set("Header", "Value")
go func() {
req2 := copyHTTPRequest(req, ioutil.NopCloser(&bytes.Buffer{}))
req2.Header.Add("Header", "Value2")
}()
_ = req.Header.Get("Header")
wg.Done()
}()
_ = origReq.Header.Get("Header")
}
origReq.Header.Get("Header")
wg.Wait()
}

View file

@ -0,0 +1,37 @@
// +build go1.5
package request_test
import (
"errors"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/mock"
"github.com/stretchr/testify/assert"
)
func TestRequestCancelRetry(t *testing.T) {
c := make(chan struct{})
reqNum := 0
s := mock.NewMockClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.Clear()
s.Handlers.UnmarshalMeta.Clear()
s.Handlers.UnmarshalError.Clear()
s.Handlers.Send.PushFront(func(r *request.Request) {
reqNum++
r.Error = errors.New("net/http: request canceled")
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
r.HTTPRequest.Cancel = c
close(c)
err := r.Send()
assert.True(t, strings.Contains(err.Error(), "canceled"))
assert.Equal(t, 1, reqNum)
}

View file

@ -0,0 +1,139 @@
package request
import (
"bytes"
"io"
"math/rand"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestOffsetReaderRead(t *testing.T) {
buf := []byte("testData")
reader := &offsetReader{buf: bytes.NewReader(buf)}
tempBuf := make([]byte, len(buf))
n, err := reader.Read(tempBuf)
assert.Equal(t, n, len(buf))
assert.Nil(t, err)
assert.Equal(t, buf, tempBuf)
}
func TestOffsetReaderSeek(t *testing.T) {
buf := []byte("testData")
reader := newOffsetReader(bytes.NewReader(buf), 0)
orig, err := reader.Seek(0, 1)
assert.NoError(t, err)
assert.Equal(t, int64(0), orig)
n, err := reader.Seek(0, 2)
assert.NoError(t, err)
assert.Equal(t, int64(len(buf)), n)
n, err = reader.Seek(orig, 0)
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
}
func TestOffsetReaderClose(t *testing.T) {
buf := []byte("testData")
reader := &offsetReader{buf: bytes.NewReader(buf)}
err := reader.Close()
assert.Nil(t, err)
tempBuf := make([]byte, len(buf))
n, err := reader.Read(tempBuf)
assert.Equal(t, n, 0)
assert.Equal(t, err, io.EOF)
}
func TestOffsetReaderCloseAndCopy(t *testing.T) {
buf := []byte("testData")
tempBuf := make([]byte, len(buf))
reader := &offsetReader{buf: bytes.NewReader(buf)}
newReader := reader.CloseAndCopy(0)
n, err := reader.Read(tempBuf)
assert.Equal(t, n, 0)
assert.Equal(t, err, io.EOF)
n, err = newReader.Read(tempBuf)
assert.Equal(t, n, len(buf))
assert.Nil(t, err)
assert.Equal(t, buf, tempBuf)
}
func TestOffsetReaderCloseAndCopyOffset(t *testing.T) {
buf := []byte("testData")
tempBuf := make([]byte, len(buf))
reader := &offsetReader{buf: bytes.NewReader(buf)}
newReader := reader.CloseAndCopy(4)
n, err := newReader.Read(tempBuf)
assert.Equal(t, n, len(buf)-4)
assert.Nil(t, err)
expected := []byte{'D', 'a', 't', 'a', 0, 0, 0, 0}
assert.Equal(t, expected, tempBuf)
}
func TestOffsetReaderRace(t *testing.T) {
wg := sync.WaitGroup{}
f := func(reader *offsetReader) {
defer wg.Done()
var err error
buf := make([]byte, 1)
_, err = reader.Read(buf)
for err != io.EOF {
_, err = reader.Read(buf)
}
}
closeFn := func(reader *offsetReader) {
defer wg.Done()
time.Sleep(time.Duration(rand.Intn(20)+1) * time.Millisecond)
reader.Close()
}
for i := 0; i < 50; i++ {
reader := &offsetReader{buf: bytes.NewReader(make([]byte, 1024*1024))}
wg.Add(1)
go f(reader)
wg.Add(1)
go closeFn(reader)
}
wg.Wait()
}
func BenchmarkOffsetReader(b *testing.B) {
bufSize := 1024 * 1024 * 100
buf := make([]byte, bufSize)
reader := &offsetReader{buf: bytes.NewReader(buf)}
tempBuf := make([]byte, 1024)
for i := 0; i < b.N; i++ {
reader.Read(tempBuf)
}
}
func BenchmarkBytesReader(b *testing.B) {
bufSize := 1024 * 1024 * 100
buf := make([]byte, bufSize)
reader := bytes.NewReader(buf)
tempBuf := make([]byte, 1024)
for i := 0; i < b.N; i++ {
reader.Read(tempBuf)
}
}

View file

@ -16,6 +16,24 @@ import (
"github.com/aws/aws-sdk-go/aws/client/metadata"
)
const (
// ErrCodeSerialization is the serialization error code that is received
// during protocol unmarshaling.
ErrCodeSerialization = "SerializationError"
// ErrCodeRead is an error that is returned during HTTP reads.
ErrCodeRead = "ReadError"
// ErrCodeResponseTimeout is the connection timeout error that is recieved
// during body reads.
ErrCodeResponseTimeout = "ResponseTimeout"
// CanceledErrorCode is the error code that will be returned by an
// API request that was canceled. Requests given a aws.Context may
// return this error when canceled.
CanceledErrorCode = "RequestCanceled"
)
// A Request is the service request to be made.
type Request struct {
Config aws.Config
@ -23,30 +41,33 @@ type Request struct {
Handlers Handlers
Retryer
Time time.Time
ExpireTime time.Duration
Operation *Operation
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
BodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
Data interface{}
RequestID string
RetryCount int
Retryable *bool
RetryDelay time.Duration
NotHoist bool
SignedHeaderVals http.Header
LastSignedAt time.Time
Time time.Time
ExpireTime time.Duration
Operation *Operation
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
BodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
Data interface{}
RequestID string
RetryCount int
Retryable *bool
RetryDelay time.Duration
NotHoist bool
SignedHeaderVals http.Header
LastSignedAt time.Time
DisableFollowRedirects bool
context aws.Context
built bool
// Need to persist an intermideant body betweend the input Body and HTTP
// Need to persist an intermediate body between the input Body and HTTP
// request body because the HTTP Client's transport can maintain a reference
// to the HTTP request's body after the client has returned. This value is
// safe to use concurrently and rewraps the input Body for each HTTP request.
// safe to use concurrently and wrap the input Body for each HTTP request.
safeBody *offsetReader
}
@ -60,14 +81,6 @@ type Operation struct {
BeforePresignFn func(r *Request) error
}
// Paginator keeps track of pagination configuration for an API operation.
type Paginator struct {
InputTokens []string
OutputTokens []string
LimitToken string
TruncationToken string
}
// New returns a new Request pointer for the service API
// operation and parameters.
//
@ -111,6 +124,94 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
return r
}
// A Option is a functional option that can augment or modify a request when
// using a WithContext API operation method.
type Option func(*Request)
// WithGetResponseHeader builds a request Option which will retrieve a single
// header value from the HTTP Response. If there are multiple values for the
// header key use WithGetResponseHeaders instead to access the http.Header
// map directly. The passed in val pointer must be non-nil.
//
// This Option can be used multiple times with a single API operation.
//
// var id2, versionID string
// svc.PutObjectWithContext(ctx, params,
// request.WithGetResponseHeader("x-amz-id-2", &id2),
// request.WithGetResponseHeader("x-amz-version-id", &versionID),
// )
func WithGetResponseHeader(key string, val *string) Option {
return func(r *Request) {
r.Handlers.Complete.PushBack(func(req *Request) {
*val = req.HTTPResponse.Header.Get(key)
})
}
}
// WithGetResponseHeaders builds a request Option which will retrieve the
// headers from the HTTP response and assign them to the passed in headers
// variable. The passed in headers pointer must be non-nil.
//
// var headers http.Header
// svc.PutObjectWithContext(ctx, params, request.WithGetResponseHeaders(&headers))
func WithGetResponseHeaders(headers *http.Header) Option {
return func(r *Request) {
r.Handlers.Complete.PushBack(func(req *Request) {
*headers = req.HTTPResponse.Header
})
}
}
// WithLogLevel is a request option that will set the request to use a specific
// log level when the request is made.
//
// svc.PutObjectWithContext(ctx, params, request.WithLogLevel(aws.LogDebugWithHTTPBody)
func WithLogLevel(l aws.LogLevelType) Option {
return func(r *Request) {
r.Config.LogLevel = aws.LogLevel(l)
}
}
// ApplyOptions will apply each option to the request calling them in the order
// the were provided.
func (r *Request) ApplyOptions(opts ...Option) {
for _, opt := range opts {
opt(r)
}
}
// Context will always returns a non-nil context. If Request does not have a
// context aws.BackgroundContext will be returned.
func (r *Request) Context() aws.Context {
if r.context != nil {
return r.context
}
return aws.BackgroundContext()
}
// SetContext adds a Context to the current request that can be used to cancel
// a in-flight request. The Context value must not be nil, or this method will
// panic.
//
// Unlike http.Request.WithContext, SetContext does not return a copy of the
// Request. It is not safe to use use a single Request value for multiple
// requests. A new Request should be created for each API operation request.
//
// Go 1.6 and below:
// The http.Request's Cancel field will be set to the Done() value of
// the context. This will overwrite the Cancel field's value.
//
// Go 1.7 and above:
// The http.Request.WithContext will be used to set the context on the underlying
// http.Request. This will create a shallow copy of the http.Request. The SDK
// may create sub contexts in the future for nested requests such as retries.
func (r *Request) SetContext(ctx aws.Context) {
if ctx == nil {
panic("context cannot be nil")
}
setRequestContext(r, ctx)
}
// WillRetry returns if the request's can be retried.
func (r *Request) WillRetry() bool {
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
@ -262,7 +363,7 @@ func (r *Request) ResetBody() {
// Related golang/go#18257
l, err := computeBodyLength(r.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to compute request body size", err)
r.Error = awserr.New(ErrCodeSerialization, "failed to compute request body size", err)
return
}
@ -344,6 +445,12 @@ func (r *Request) GetBody() io.ReadSeeker {
//
// Send will not close the request.Request's body.
func (r *Request) Send() error {
defer func() {
// Regardless of success or failure of the request trigger the Complete
// request handlers.
r.Handlers.Complete.Run(r)
}()
for {
if aws.BoolValue(r.Retryable) {
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
@ -446,6 +553,9 @@ func shouldRetryCancel(r *Request) bool {
timeoutErr := false
errStr := r.Error.Error()
if ok {
if awsErr.Code() == CanceledErrorCode {
return false
}
err := awsErr.OrigErr()
netErr, netOK := err.(net.Error)
timeoutErr = netOK && netErr.Temporary()

View file

@ -0,0 +1,11 @@
// +build !go1.6
package request_test
import (
"errors"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var errTimeout = awserr.New("foo", "bar", errors.New("net/http: request canceled Timeout"))

View file

@ -0,0 +1,51 @@
// +build go1.6
package request_test
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
)
// go version 1.4 and 1.5 do not return an error. Version 1.5 will url encode
// the uri while 1.4 will not
func TestRequestInvalidEndpoint(t *testing.T) {
endpoint := "http://localhost:90 "
r := request.New(
aws.Config{},
metadata.ClientInfo{Endpoint: endpoint},
defaults.Handlers(),
client.DefaultRetryer{},
&request.Operation{},
nil,
nil,
)
assert.Error(t, r.Error)
}
type timeoutErr struct {
error
}
var errTimeout = awserr.New("foo", "bar", &timeoutErr{
errors.New("net/http: request canceled"),
})
func (e *timeoutErr) Timeout() bool {
return true
}
func (e *timeoutErr) Temporary() bool {
return true
}

View file

@ -0,0 +1,24 @@
// +build !go1.8
package request
import (
"net/http"
"strings"
"testing"
)
func TestResetBody_WithEmptyBody(t *testing.T) {
r := Request{
HTTPRequest: &http.Request{},
}
reader := strings.NewReader("")
r.Body = reader
r.ResetBody()
if a, e := r.HTTPRequest.Body, (noBody{}); a != e {
t.Errorf("expected request body to be set to reader, got %#v", r.HTTPRequest.Body)
}
}

View file

@ -0,0 +1,25 @@
// +build go1.8
package request
import (
"net/http"
"strings"
"testing"
)
func TestResetBody_WithEmptyBody(t *testing.T) {
r := Request{
HTTPRequest: &http.Request{},
}
reader := strings.NewReader("")
r.Body = reader
r.ResetBody()
if a, e := r.HTTPRequest.Body, http.NoBody; a != e {
t.Errorf("expected request body to be set to reader, got %#v",
r.HTTPRequest.Body)
}
}

View file

@ -0,0 +1,14 @@
// +build go1.7
package request
import "github.com/aws/aws-sdk-go/aws"
// setContext updates the Request to use the passed in context for cancellation.
// Context will also be used for request retry delay.
//
// Creates shallow copy of the http.Request with the WithContext method.
func setRequestContext(r *Request, ctx aws.Context) {
r.context = ctx
r.HTTPRequest = r.HTTPRequest.WithContext(ctx)
}

View file

@ -0,0 +1,14 @@
// +build !go1.7
package request
import "github.com/aws/aws-sdk-go/aws"
// setContext updates the Request to use the passed in context for cancellation.
// Context will also be used for request retry delay.
//
// Creates shallow copy of the http.Request with the WithContext method.
func setRequestContext(r *Request, ctx aws.Context) {
r.context = ctx
r.HTTPRequest.Cancel = ctx.Done()
}

View file

@ -0,0 +1,46 @@
package request_test
import (
"fmt"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestRequest_SetContext(t *testing.T) {
svc := awstesting.NewClient()
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
r.SetContext(ctx)
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
err := r.Send()
if err == nil {
t.Fatalf("expected error, got none")
}
// Only check against canceled because go 1.6 will not use the context's
// Err().
if e, a := "canceled", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q to be in %q, but was not", e, a)
}
}
func TestRequest_SetContextPanic(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatalf("expect SetContext to panic, did not")
}
}()
r := &request.Request{}
r.SetContext(nil)
}

View file

@ -0,0 +1,27 @@
package request
import (
"testing"
)
func TestCopy(t *testing.T) {
handlers := Handlers{}
op := &Operation{}
op.HTTPMethod = "Foo"
req := &Request{}
req.Operation = op
req.Handlers = handlers
r := req.copy()
if r == req {
t.Fatal("expect request pointer copy to be different")
}
if r.Operation == req.Operation {
t.Errorf("expect request operation pointer to be different")
}
if e, a := req.Operation.HTTPMethod, r.Operation.HTTPMethod; e != a {
t.Errorf("expect %q http method, got %q", e, a)
}
}

View file

@ -2,29 +2,125 @@ package request
import (
"reflect"
"sync/atomic"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
//type Paginater interface {
// HasNextPage() bool
// NextPage() *Request
// EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error
//}
// A Pagination provides paginating of SDK API operations which are paginatable.
// Generally you should not use this type directly, but use the "Pages" API
// operations method to automatically perform pagination for you. Such as,
// "S3.ListObjectsPages", and "S3.ListObjectsPagesWithContext" methods.
//
// Pagination differs from a Paginator type in that pagination is the type that
// does the pagination between API operations, and Paginator defines the
// configuration that will be used per page request.
//
// cont := true
// for p.Next() && cont {
// data := p.Page().(*s3.ListObjectsOutput)
// // process the page's data
// }
// return p.Err()
//
// See service client API operation Pages methods for examples how the SDK will
// use the Pagination type.
type Pagination struct {
// Function to return a Request value for each pagination request.
// Any configuration or handlers that need to be applied to the request
// prior to getting the next page should be done here before the request
// returned.
//
// NewRequest should always be built from the same API operations. It is
// undefined if different API operations are returned on subsequent calls.
NewRequest func() (*Request, error)
// HasNextPage returns true if this request has more pages of data available.
func (r *Request) HasNextPage() bool {
return len(r.nextPageTokens()) > 0
started bool
nextTokens []interface{}
err error
curPage interface{}
}
// nextPageTokens returns the tokens to use when asking for the next page of
// data.
// HasNextPage will return true if Pagination is able to determine that the API
// operation has additional pages. False will be returned if there are no more
// pages remaining.
//
// Will always return true if Next has not been called yet.
func (p *Pagination) HasNextPage() bool {
return !(p.started && len(p.nextTokens) == 0)
}
// Err returns the error Pagination encountered when retrieving the next page.
func (p *Pagination) Err() error {
return p.err
}
// Page returns the current page. Page should only be called after a successful
// call to Next. It is undefined what Page will return if Page is called after
// Next returns false.
func (p *Pagination) Page() interface{} {
return p.curPage
}
// Next will attempt to retrieve the next page for the API operation. When a page
// is retrieved true will be returned. If the page cannot be retrieved, or there
// are no more pages false will be returned.
//
// Use the Page method to retrieve the current page data. The data will need
// to be cast to the API operation's output type.
//
// Use the Err method to determine if an error occurred if Page returns false.
func (p *Pagination) Next() bool {
if !p.HasNextPage() {
return false
}
req, err := p.NewRequest()
if err != nil {
p.err = err
return false
}
if p.started {
for i, intok := range req.Operation.InputTokens {
awsutil.SetValueAtPath(req.Params, intok, p.nextTokens[i])
}
}
p.started = true
err = req.Send()
if err != nil {
p.err = err
return false
}
p.nextTokens = req.nextPageTokens()
p.curPage = req.Data
return true
}
// A Paginator is the configuration data that defines how an API operation
// should be paginated. This type is used by the API service models to define
// the generated pagination config for service APIs.
//
// The Pagination type is what provides iterating between pages of an API. It
// is only used to store the token metadata the SDK should use for performing
// pagination.
type Paginator struct {
InputTokens []string
OutputTokens []string
LimitToken string
TruncationToken string
}
// nextPageTokens returns the tokens to use when asking for the next page of data.
func (r *Request) nextPageTokens() []interface{} {
if r.Operation.Paginator == nil {
return nil
}
if r.Operation.TruncationToken != "" {
tr, _ := awsutil.ValuesAtPath(r.Data, r.Operation.TruncationToken)
if len(tr) == 0 {
@ -61,9 +157,40 @@ func (r *Request) nextPageTokens() []interface{} {
return tokens
}
// Ensure a deprecated item is only logged once instead of each time its used.
func logDeprecatedf(logger aws.Logger, flag *int32, msg string) {
if logger == nil {
return
}
if atomic.CompareAndSwapInt32(flag, 0, 1) {
logger.Log(msg)
}
}
var (
logDeprecatedHasNextPage int32
logDeprecatedNextPage int32
logDeprecatedEachPage int32
)
// HasNextPage returns true if this request has more pages of data available.
//
// Deprecated Use Pagination type for configurable pagination of API operations
func (r *Request) HasNextPage() bool {
logDeprecatedf(r.Config.Logger, &logDeprecatedHasNextPage,
"Request.HasNextPage deprecated. Use Pagination type for configurable pagination of API operations")
return len(r.nextPageTokens()) > 0
}
// NextPage returns a new Request that can be executed to return the next
// page of result data. Call .Send() on this request to execute it.
//
// Deprecated Use Pagination type for configurable pagination of API operations
func (r *Request) NextPage() *Request {
logDeprecatedf(r.Config.Logger, &logDeprecatedNextPage,
"Request.NextPage deprecated. Use Pagination type for configurable pagination of API operations")
tokens := r.nextPageTokens()
if len(tokens) == 0 {
return nil
@ -90,7 +217,12 @@ func (r *Request) NextPage() *Request {
// as the structure "T". The lastPage value represents whether the page is
// the last page of data or not. The return value of this function should
// return true to keep iterating or false to stop.
//
// Deprecated Use Pagination type for configurable pagination of API operations
func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error {
logDeprecatedf(r.Config.Logger, &logDeprecatedEachPage,
"Request.EachPage deprecated. Use Pagination type for configurable pagination of API operations")
for page := r; page != nil; page = page.NextPage() {
if err := page.Send(); err != nil {
return err

View file

@ -0,0 +1,604 @@
package request_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/route53"
"github.com/aws/aws-sdk-go/service/s3"
)
// Use DynamoDB methods for simplicity
func TestPaginationQueryPage(t *testing.T) {
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []map[string]*dynamodb.AttributeValue{}, []map[string]*dynamodb.AttributeValue{}, 0, false
reqNum := 0
resps := []*dynamodb.QueryOutput{
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key1")}},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
{
"key": {S: aws.String("key1")},
},
},
},
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key2")}},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
{
"key": {S: aws.String("key2")},
},
},
},
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
{
"key": {S: aws.String("key3")},
},
},
},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.QueryInput)
if in == nil {
tokens = append(tokens, nil)
} else if len(in.ExclusiveStartKey) != 0 {
tokens = append(tokens, in.ExclusiveStartKey)
}
})
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.QueryInput{
Limit: aws.Int64(2),
TableName: aws.String("tablename"),
}
err := db.QueryPages(params, func(p *dynamodb.QueryOutput, last bool) bool {
numPages++
for _, item := range p.Items {
pages = append(pages, item)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Nil(t, err)
assert.Equal(t,
[]map[string]*dynamodb.AttributeValue{
{"key": {S: aws.String("key1")}},
{"key": {S: aws.String("key2")}},
}, tokens)
assert.Equal(t,
[]map[string]*dynamodb.AttributeValue{
{"key": {S: aws.String("key1")}},
{"key": {S: aws.String("key2")}},
{"key": {S: aws.String("key3")}},
}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, params.ExclusiveStartKey)
}
// Use DynamoDB methods for simplicity
func TestPagination(t *testing.T) {
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
for _, t := range p.TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
assert.Nil(t, params.ExclusiveStartTableName)
}
// Use DynamoDB methods for simplicity
func TestPaginationEachPage(t *testing.T) {
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
req, _ := db.ListTablesRequest(params)
err := req.EachPage(func(p interface{}, last bool) bool {
numPages++
for _, t := range p.(*dynamodb.ListTablesOutput).TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
}
// Use DynamoDB methods for simplicity
func TestPaginationEarlyExit(t *testing.T) {
db := dynamodb.New(unit.Session)
numPages, gotToEnd := 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
if numPages == 2 {
return false
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, 2, numPages)
assert.False(t, gotToEnd)
assert.Nil(t, err)
}
func TestSkipPagination(t *testing.T) {
client := s3.New(unit.Session)
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = &s3.HeadBucketOutput{}
})
req, _ := client.HeadBucketRequest(&s3.HeadBucketInput{Bucket: aws.String("bucket")})
numPages, gotToEnd := 0, false
req.EachPage(func(p interface{}, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
assert.Equal(t, 1, numPages)
assert.True(t, gotToEnd)
}
// Use S3 for simplicity
func TestPaginationTruncation(t *testing.T) {
client := s3.New(unit.Session)
reqNum := 0
resps := []*s3.ListObjectsOutput{
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}},
{IsTruncated: aws.Bool(false), Contents: []*s3.Object{{Key: aws.String("Key3")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key4")}}},
}
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &s3.ListObjectsInput{Bucket: aws.String("bucket")}
results := []string{}
err := client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2", "Key3"}, results)
assert.Nil(t, err)
// Try again without truncation token at all
reqNum = 0
resps[1].IsTruncated = nil
resps[2].IsTruncated = aws.Bool(true)
results = []string{}
err = client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2"}, results)
assert.Nil(t, err)
}
func TestPaginationNilToken(t *testing.T) {
client := route53.New(unit.Session)
reqNum := 0
resps := []*route53.ListResourceRecordSetsOutput{
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("first.example.com.")},
},
IsTruncated: aws.Bool(true),
NextRecordName: aws.String("second.example.com."),
NextRecordType: aws.String("MX"),
NextRecordIdentifier: aws.String("second"),
MaxItems: aws.String("1"),
},
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("second.example.com.")},
},
IsTruncated: aws.Bool(true),
NextRecordName: aws.String("third.example.com."),
NextRecordType: aws.String("MX"),
MaxItems: aws.String("1"),
},
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("third.example.com.")},
},
IsTruncated: aws.Bool(false),
MaxItems: aws.String("1"),
},
}
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
idents := []string{}
client.Handlers.Build.PushBack(func(r *request.Request) {
p := r.Params.(*route53.ListResourceRecordSetsInput)
idents = append(idents, aws.StringValue(p.StartRecordIdentifier))
})
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &route53.ListResourceRecordSetsInput{
HostedZoneId: aws.String("id-zone"),
}
results := []string{}
err := client.ListResourceRecordSetsPages(params, func(p *route53.ListResourceRecordSetsOutput, last bool) bool {
results = append(results, *p.ResourceRecordSets[0].Name)
return true
})
assert.NoError(t, err)
assert.Equal(t, []string{"", "second", ""}, idents)
assert.Equal(t, []string{"first.example.com.", "second.example.com.", "third.example.com."}, results)
}
func TestPaginationNilInput(t *testing.T) {
// Code generation doesn't have a great way to verify the code is correct
// other than being run via unit tests in the SDK. This should be fixed
// So code generation can be validated independently.
client := s3.New(unit.Session)
client.Handlers.Validate.Clear()
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = &s3.ListObjectsOutput{}
})
gotToEnd := false
numPages := 0
err := client.ListObjectsPages(nil, func(p *s3.ListObjectsOutput, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 1, numPages; e != a {
t.Errorf("expect %d number pages but got %d", e, a)
}
if !gotToEnd {
t.Errorf("expect to of gotten to end, did not")
}
}
func TestPaginationWithContextNilInput(t *testing.T) {
// Code generation doesn't have a great way to verify the code is correct
// other than being run via unit tests in the SDK. This should be fixed
// So code generation can be validated independently.
client := s3.New(unit.Session)
client.Handlers.Validate.Clear()
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = &s3.ListObjectsOutput{}
})
gotToEnd := false
numPages := 0
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
err := client.ListObjectsPagesWithContext(ctx, nil, func(p *s3.ListObjectsOutput, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 1, numPages; e != a {
t.Errorf("expect %d number pages but got %d", e, a)
}
if !gotToEnd {
t.Errorf("expect to of gotten to end, did not")
}
}
type testPageInput struct {
NextToken string
}
type testPageOutput struct {
Value string
NextToken *string
}
func TestPagination_Standalone(t *testing.T) {
expect := []struct {
Value, PrevToken, NextToken string
}{
{"FirstValue", "InitalToken", "FirstToken"},
{"SecondValue", "FirstToken", "SecondToken"},
{"ThirdValue", "SecondToken", ""},
}
input := testPageInput{
NextToken: expect[0].PrevToken,
}
c := awstesting.NewClient()
i := 0
p := request.Pagination{
NewRequest: func() (*request.Request, error) {
r := c.NewRequest(
&request.Operation{
Name: "Operation",
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
},
},
&input, &testPageOutput{},
)
// Setup handlers for testing
r.Handlers.Clear()
r.Handlers.Build.PushBack(func(req *request.Request) {
in := req.Params.(*testPageInput)
if e, a := expect[i].PrevToken, in.NextToken; e != a {
t.Errorf("%d, expect NextToken input %q, got %q", i, e, a)
}
})
r.Handlers.Unmarshal.PushBack(func(req *request.Request) {
out := &testPageOutput{
Value: expect[i].Value,
}
if len(expect[i].NextToken) > 0 {
out.NextToken = aws.String(expect[i].NextToken)
}
req.Data = out
})
return r, nil
},
}
for p.Next() {
data := p.Page().(*testPageOutput)
if e, a := expect[i].Value, data.Value; e != a {
t.Errorf("%d, expect Value to be %q, got %q", i, e, a)
}
if e, a := expect[i].NextToken, aws.StringValue(data.NextToken); e != a {
t.Errorf("%d, expect NextToken to be %q, got %q", i, e, a)
}
i++
}
if e, a := len(expect), i; e != a {
t.Errorf("expected to process %d pages, did %d", e, a)
}
if err := p.Err(); err != nil {
t.Fatalf("%d, expected no error, got %v", i, err)
}
}
// Benchmarks
var benchResps = []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE")}},
}
var benchDb = func() *dynamodb.DynamoDB {
db := dynamodb.New(unit.Session)
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
return db
}
func BenchmarkCodegenIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
iter := func(fn func(*dynamodb.ListTablesOutput, bool) bool) error {
page, _ := db.ListTablesRequest(input)
for ; page != nil; page = page.NextPage() {
page.Send()
out := page.Data.(*dynamodb.ListTablesOutput)
if result := fn(out, !page.HasNextPage()); page.Error != nil || !result {
return page.Error
}
}
return nil
}
for i := 0; i < b.N; i++ {
reqNum = 0
iter(func(p *dynamodb.ListTablesOutput, last bool) bool {
return true
})
}
}
func BenchmarkEachPageIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
for i := 0; i < b.N; i++ {
reqNum = 0
req, _ := db.ListTablesRequest(input)
req.EachPage(func(p interface{}, last bool) bool {
return true
})
}
}

View file

@ -0,0 +1,58 @@
package request
import (
"bytes"
"net/http"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
)
func TestResetBody_WithBodyContents(t *testing.T) {
r := Request{
HTTPRequest: &http.Request{},
}
reader := strings.NewReader("abc")
r.Body = reader
r.ResetBody()
if v, ok := r.HTTPRequest.Body.(*offsetReader); !ok || v == nil {
t.Errorf("expected request body to be set to reader, got %#v",
r.HTTPRequest.Body)
}
}
func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) {
cases := []struct {
Method string
IsNoBody bool
}{
{"GET", true},
{"HEAD", true},
{"DELETE", true},
{"PUT", false},
{"PATCH", false},
{"POST", false},
}
reader := aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc")))
for i, c := range cases {
r := Request{
HTTPRequest: &http.Request{},
Operation: &Operation{
HTTPMethod: c.Method,
},
}
r.SetReaderBody(reader)
if a, e := r.HTTPRequest.Body == noBodyReader, c.IsNoBody; a != e {
t.Errorf("%d, expect body to be set to noBody(%t), but was %t", i, e, a)
}
}
}

View file

@ -0,0 +1,762 @@
package request_test
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"runtime"
"strconv"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
type testData struct {
Data string
}
func body(str string) io.ReadCloser {
return ioutil.NopCloser(bytes.NewReader([]byte(str)))
}
func unmarshal(req *request.Request) {
defer req.HTTPResponse.Body.Close()
if req.Data != nil {
json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data)
}
return
}
func unmarshalError(req *request.Request) {
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err)
return
}
if len(bodyBytes) == 0 {
req.Error = awserr.NewRequestFailure(
awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")),
req.HTTPResponse.StatusCode,
"",
)
return
}
var jsonErr jsonErrorResponse
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err)
return
}
req.Error = awserr.NewRequestFailure(
awserr.New(jsonErr.Code, jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
"",
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}
// test that retries occur for 5xx status codes
func TestRequestRecoverRetry5xx(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 2, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
if e, a := "valid", out.Data; e != a {
t.Errorf("expect %q output got %q", e, a)
}
}
// test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry`
func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 2, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
if e, a := "valid", out.Data; e != a {
t.Errorf("expect %q output got %q", e, a)
}
}
// test that retries don't occur for 4xx status codes with a response type that can't be retried
func TestRequest4xxUnretryable(t *testing.T) {
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)}
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
if err == nil {
t.Fatalf("expect error, but did not get one")
}
aerr := err.(awserr.RequestFailure)
if e, a := 401, aerr.StatusCode(); e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
if e, a := "SignatureDoesNotMatch", aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := "Signature does not match.", aerr.Message(); e != a {
t.Errorf("expect %q error message, got %q", e, a)
}
if e, a := 0, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
}
func TestRequestExhaustRetries(t *testing.T) {
delays := []time.Duration{}
sleepDelay := func(delay time.Duration) {
delays = append(delays, delay)
}
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
}
s := awstesting.NewClient(aws.NewConfig().WithSleepDelay(sleepDelay))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := r.Send()
if err == nil {
t.Fatalf("expect error, but did not get one")
}
aerr := err.(awserr.RequestFailure)
if e, a := 500, aerr.StatusCode(); e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
if e, a := "UnknownError", aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := "An error occurred.", aerr.Message(); e != a {
t.Errorf("expect %q error message, got %q", e, a)
}
if e, a := 3, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
expectDelays := []struct{ min, max time.Duration }{{30, 59}, {60, 118}, {120, 236}}
for i, v := range delays {
min := expectDelays[i].min * time.Millisecond
max := expectDelays[i].max * time.Millisecond
if !(min <= v && v <= max) {
t.Errorf("Expect delay to be within range, i:%d, v:%s, min:%s, max:%s",
i, v, min, max)
}
}
}
// test that the request is retried after the credentials are expired.
func TestRequestRecoverExpiredCreds(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := awstesting.NewClient(&aws.Config{MaxRetries: aws.Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
credExpiredBeforeRetry := false
credExpiredAfterRetry := false
s.Handlers.AfterRetry.PushBack(func(r *request.Request) {
credExpiredAfterRetry = r.Config.Credentials.IsExpired()
})
s.Handlers.Sign.Clear()
s.Handlers.Sign.PushBack(func(r *request.Request) {
r.Config.Credentials.Get()
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if credExpiredBeforeRetry {
t.Errorf("Expect valid creds before retry check")
}
if !credExpiredAfterRetry {
t.Errorf("Expect expired creds after retry check")
}
if s.Config.Credentials.IsExpired() {
t.Errorf("Expect valid creds after cred expired recovery")
}
if e, a := 1, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
if e, a := "valid", out.Data; e != a {
t.Errorf("expect %q output got %q", e, a)
}
}
func TestMakeAddtoUserAgentHandler(t *testing.T) {
fn := request.MakeAddToUserAgentHandler("name", "version", "extra1", "extra2")
r := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}}
r.HTTPRequest.Header.Set("User-Agent", "foo/bar")
fn(r)
if e, a := "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"); e != a {
t.Errorf("expect %q user agent, got %q", e, a)
}
}
func TestMakeAddtoUserAgentFreeFormHandler(t *testing.T) {
fn := request.MakeAddToUserAgentFreeFormHandler("name/version (extra1; extra2)")
r := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}}
r.HTTPRequest.Header.Set("User-Agent", "foo/bar")
fn(r)
if e, a := "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"); e != a {
t.Errorf("expect %q user agent, got %q", e, a)
}
}
func TestRequestUserAgent(t *testing.T) {
s := awstesting.NewClient(&aws.Config{Region: aws.String("us-east-1")})
// s.Handlers.Validate.Clear()
req := s.NewRequest(&request.Operation{Name: "Operation"}, nil, &testData{})
req.HTTPRequest.Header.Set("User-Agent", "foo/bar")
if err := req.Build(); err != nil {
t.Fatalf("expect no error, got %v", err)
}
expectUA := fmt.Sprintf("foo/bar %s/%s (%s; %s; %s)",
aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH)
if e, a := expectUA, req.HTTPRequest.Header.Get("User-Agent"); e != a {
t.Errorf("expect %q user agent, got %q", e, a)
}
}
func TestRequestThrottleRetries(t *testing.T) {
delays := []time.Duration{}
sleepDelay := func(delay time.Duration) {
delays = append(delays, delay)
}
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)},
}
s := awstesting.NewClient(aws.NewConfig().WithSleepDelay(sleepDelay))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := r.Send()
if err == nil {
t.Fatalf("expect error, but did not get one")
}
aerr := err.(awserr.RequestFailure)
if e, a := 500, aerr.StatusCode(); e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
if e, a := "Throttling", aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := "An error occurred.", aerr.Message(); e != a {
t.Errorf("expect %q error message, got %q", e, a)
}
if e, a := 3, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
expectDelays := []struct{ min, max time.Duration }{{500, 999}, {1000, 1998}, {2000, 3996}}
for i, v := range delays {
min := expectDelays[i].min * time.Millisecond
max := expectDelays[i].max * time.Millisecond
if !(min <= v && v <= max) {
t.Errorf("Expect delay to be within range, i:%d, v:%s, min:%s, max:%s",
i, v, min, max)
}
}
}
// test that retries occur for request timeouts when response.Body can be nil
func TestRequestRecoverTimeoutWithNilBody(t *testing.T) {
reqNum := 0
reqs := []*http.Response{
{StatusCode: 0, Body: nil}, // body can be nil when requests time out
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
errors := []error{
errTimeout, nil,
}
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.AfterRetry.Clear() // force retry on all errors
s.Handlers.AfterRetry.PushBack(func(r *request.Request) {
if r.Error != nil {
r.Error = nil
r.Retryable = aws.Bool(true)
r.RetryCount++
}
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = reqs[reqNum]
r.Error = errors[reqNum]
reqNum++
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 1, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
if e, a := "valid", out.Data; e != a {
t.Errorf("expect %q output got %q", e, a)
}
}
func TestRequestRecoverTimeoutWithNilResponse(t *testing.T) {
reqNum := 0
reqs := []*http.Response{
nil,
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
errors := []error{
errTimeout,
nil,
}
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.AfterRetry.Clear() // force retry on all errors
s.Handlers.AfterRetry.PushBack(func(r *request.Request) {
if r.Error != nil {
r.Error = nil
r.Retryable = aws.Bool(true)
r.RetryCount++
}
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = reqs[reqNum]
r.Error = errors[reqNum]
reqNum++
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
if e, a := 1, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
if e, a := "valid", out.Data; e != a {
t.Errorf("expect %q output got %q", e, a)
}
}
func TestRequest_NoBody(t *testing.T) {
cases := []string{
"GET", "HEAD", "DELETE",
"PUT", "POST", "PATCH",
}
for i, c := range cases {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if v := r.TransferEncoding; len(v) > 0 {
t.Errorf("%d, expect no body sent with Transfer-Encoding, %v", i, v)
}
outMsg := []byte(`{"Value": "abc"}`)
if b, err := ioutil.ReadAll(r.Body); err != nil {
t.Fatalf("%d, expect no error reading request body, got %v", i, err)
} else if n := len(b); n > 0 {
t.Errorf("%d, expect no request body, got %d bytes", i, n)
}
w.Header().Set("Content-Length", strconv.Itoa(len(outMsg)))
if _, err := w.Write(outMsg); err != nil {
t.Fatalf("%d, expect no error writing server response, got %v", i, err)
}
}))
s := awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
MaxRetries: aws.Int(0),
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
})
s.Handlers.Build.PushBack(rest.Build)
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
in := struct {
Bucket *string `location:"uri" locationName:"bucket"`
Key *string `location:"uri" locationName:"key"`
}{
Bucket: aws.String("mybucket"), Key: aws.String("myKey"),
}
out := struct {
Value *string
}{}
r := s.NewRequest(&request.Operation{
Name: "OpName", HTTPMethod: c, HTTPPath: "/{bucket}/{key+}",
}, &in, &out)
if err := r.Send(); err != nil {
t.Fatalf("%d, expect no error sending request, got %v", i, err)
}
}
}
func TestIsSerializationErrorRetryable(t *testing.T) {
testCases := []struct {
err error
expected bool
}{
{
err: awserr.New(request.ErrCodeSerialization, "foo error", nil),
expected: false,
},
{
err: awserr.New("ErrFoo", "foo error", nil),
expected: false,
},
{
err: nil,
expected: false,
},
{
err: awserr.New(request.ErrCodeSerialization, "foo error", stubConnectionResetError),
expected: true,
},
}
for i, c := range testCases {
r := &request.Request{
Error: c.err,
}
if r.IsErrorRetryable() != c.expected {
t.Errorf("Case %d: Expected %v, but received %v", i+1, c.expected, !c.expected)
}
}
}
func TestWithLogLevel(t *testing.T) {
r := &request.Request{}
opt := request.WithLogLevel(aws.LogDebugWithHTTPBody)
r.ApplyOptions(opt)
if !r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) {
t.Errorf("expect log level to be set, but was not, %v",
r.Config.LogLevel.Value())
}
}
func TestWithGetResponseHeader(t *testing.T) {
r := &request.Request{}
var val, val2 string
r.ApplyOptions(
request.WithGetResponseHeader("x-a-header", &val),
request.WithGetResponseHeader("x-second-header", &val2),
)
r.HTTPResponse = &http.Response{
Header: func() http.Header {
h := http.Header{}
h.Set("x-a-header", "first")
h.Set("x-second-header", "second")
return h
}(),
}
r.Handlers.Complete.Run(r)
if e, a := "first", val; e != a {
t.Errorf("expect %q header value got %q", e, a)
}
if e, a := "second", val2; e != a {
t.Errorf("expect %q header value got %q", e, a)
}
}
func TestWithGetResponseHeaders(t *testing.T) {
r := &request.Request{}
var headers http.Header
opt := request.WithGetResponseHeaders(&headers)
r.ApplyOptions(opt)
r.HTTPResponse = &http.Response{
Header: func() http.Header {
h := http.Header{}
h.Set("x-a-header", "headerValue")
return h
}(),
}
r.Handlers.Complete.Run(r)
if e, a := "headerValue", headers.Get("x-a-header"); e != a {
t.Errorf("expect %q header value got %q", e, a)
}
}
type connResetCloser struct {
}
func (rc *connResetCloser) Read(b []byte) (int, error) {
return 0, stubConnectionResetError
}
func (rc *connResetCloser) Close() error {
return nil
}
func TestSerializationErrConnectionReset(t *testing.T) {
count := 0
handlers := request.Handlers{}
handlers.Send.PushBack(func(r *request.Request) {
count++
r.HTTPResponse = &http.Response{}
r.HTTPResponse.Body = &connResetCloser{}
})
handlers.Sign.PushBackNamed(v4.SignRequestHandler)
handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
op := &request.Operation{
Name: "op",
HTTPMethod: "POST",
HTTPPath: "/",
}
meta := metadata.ClientInfo{
ServiceName: "fooService",
SigningName: "foo",
SigningRegion: "foo",
Endpoint: "localhost",
APIVersion: "2001-01-01",
JSONVersion: "1.1",
TargetPrefix: "Foo",
}
cfg := unit.Session.Config.Copy()
cfg.MaxRetries = aws.Int(5)
req := request.New(
*cfg,
meta,
handlers,
client.DefaultRetryer{NumMaxRetries: 5},
op,
&struct {
}{},
&struct {
}{},
)
osErr := stubConnectionResetError
req.ApplyOptions(request.WithResponseReadTimeout(time.Second))
err := req.Send()
if err == nil {
t.Error("Expected rror 'SerializationError', but received nil")
}
if aerr, ok := err.(awserr.Error); ok && aerr.Code() != "SerializationError" {
t.Errorf("Expected 'SerializationError', but received %q", aerr.Code())
} else if !ok {
t.Errorf("Expected 'awserr.Error', but received %v", reflect.TypeOf(err))
} else if aerr.OrigErr().Error() != osErr.Error() {
t.Errorf("Expected %q, but received %q", osErr.Error(), aerr.OrigErr().Error())
}
if count != 6 {
t.Errorf("Expected '6', but received %d", count)
}
}
type testRetryer struct {
shouldRetry bool
}
func (d *testRetryer) MaxRetries() int {
return 3
}
// RetryRules returns the delay duration before retrying this request again
func (d *testRetryer) RetryRules(r *request.Request) time.Duration {
return time.Duration(time.Millisecond)
}
func (d *testRetryer) ShouldRetry(r *request.Request) bool {
d.shouldRetry = true
if r.Retryable != nil {
return *r.Retryable
}
if r.HTTPResponse.StatusCode >= 500 {
return true
}
return r.IsErrorRetryable()
}
func TestEnforceShouldRetryCheck(t *testing.T) {
tp := &http.Transport{
Proxy: http.ProxyFromEnvironment,
ResponseHeaderTimeout: 1 * time.Millisecond,
}
client := &http.Client{Transport: tp}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Millisecond)
}))
retryer := &testRetryer{}
s := awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
MaxRetries: aws.Int(0),
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
Retryer: retryer,
HTTPClient: client,
EnforceShouldRetryCheck: aws.Bool(true),
})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
if err == nil {
t.Fatalf("expect error, but got nil")
}
if e, a := 3, int(r.RetryCount); e != a {
t.Errorf("expect %d retry count, got %d", e, a)
}
if !retryer.shouldRetry {
t.Errorf("expect 'true' for ShouldRetry, but got %v", retryer.shouldRetry)
}
}
type errReader struct {
err error
}
func (reader *errReader) Read(b []byte) (int, error) {
return 0, reader.err
}
func (reader *errReader) Close() error {
return nil
}

View file

@ -26,8 +26,10 @@ func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config {
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"RequestTimeout": {},
"RequestError": {},
"RequestTimeout": {},
ErrCodeResponseTimeout: {},
"RequestTimeoutException": {}, // Glacier's flavor of RequestTimeout
}
var throttleCodes = map[string]struct{}{
@ -68,35 +70,85 @@ func isCodeExpiredCreds(code string) bool {
return ok
}
var validParentCodes = map[string]struct{}{
ErrCodeSerialization: struct{}{},
ErrCodeRead: struct{}{},
}
func isNestedErrorRetryable(parentErr awserr.Error) bool {
if parentErr == nil {
return false
}
if _, ok := validParentCodes[parentErr.Code()]; !ok {
return false
}
err := parentErr.OrigErr()
if err == nil {
return false
}
if aerr, ok := err.(awserr.Error); ok {
return isCodeRetryable(aerr.Code())
}
return isErrConnectionReset(err)
}
// IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if the request has no Error set.
func (r *Request) IsErrorRetryable() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
// Returns false if error is nil.
func IsErrorRetryable(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr)
}
}
return false
}
// IsErrorThrottle returns whether the error is to be throttled based on its code.
// Returns false if the request has no Error set
func (r *Request) IsErrorThrottle() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeThrottle(err.Code())
// Returns false if error is nil.
func IsErrorThrottle(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeThrottle(aerr.Code())
}
}
return false
}
// IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set.
func (r *Request) IsErrorExpired() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeExpiredCreds(err.Code())
// IsErrorExpiredCreds returns whether the error code is a credential expiry error.
// Returns false if error is nil.
func IsErrorExpiredCreds(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeExpiredCreds(aerr.Code())
}
}
return false
}
// IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if the request has no Error set.
//
// Alias for the utility function IsErrorRetryable
func (r *Request) IsErrorRetryable() bool {
return IsErrorRetryable(r.Error)
}
// IsErrorThrottle returns whether the error is to be throttled based on its code.
// Returns false if the request has no Error set
//
// Alias for the utility function IsErrorThrottle
func (r *Request) IsErrorThrottle() bool {
return IsErrorThrottle(r.Error)
}
// IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set.
//
// Alias for the utility function IsErrorExpiredCreds
func (r *Request) IsErrorExpired() bool {
return IsErrorExpiredCreds(r.Error)
}

View file

@ -0,0 +1,16 @@
package request
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws/awserr"
)
func TestRequestThrottling(t *testing.T) {
req := Request{}
req.Error = awserr.New("Throttling", "", nil)
assert.True(t, req.IsErrorThrottle())
}

View file

@ -0,0 +1,94 @@
package request
import (
"io"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var timeoutErr = awserr.New(
ErrCodeResponseTimeout,
"read on body has reached the timeout limit",
nil,
)
type readResult struct {
n int
err error
}
// timeoutReadCloser will handle body reads that take too long.
// We will return a ErrReadTimeout error if a timeout occurs.
type timeoutReadCloser struct {
reader io.ReadCloser
duration time.Duration
}
// Read will spin off a goroutine to call the reader's Read method. We will
// select on the timer's channel or the read's channel. Whoever completes first
// will be returned.
func (r *timeoutReadCloser) Read(b []byte) (int, error) {
timer := time.NewTimer(r.duration)
c := make(chan readResult, 1)
go func() {
n, err := r.reader.Read(b)
timer.Stop()
c <- readResult{n: n, err: err}
}()
select {
case data := <-c:
return data.n, data.err
case <-timer.C:
return 0, timeoutErr
}
}
func (r *timeoutReadCloser) Close() error {
return r.reader.Close()
}
const (
// HandlerResponseTimeout is what we use to signify the name of the
// response timeout handler.
HandlerResponseTimeout = "ResponseTimeoutHandler"
)
// adaptToResponseTimeoutError is a handler that will replace any top level error
// to a ErrCodeResponseTimeout, if its child is that.
func adaptToResponseTimeoutError(req *Request) {
if err, ok := req.Error.(awserr.Error); ok {
aerr, ok := err.OrigErr().(awserr.Error)
if ok && aerr.Code() == ErrCodeResponseTimeout {
req.Error = aerr
}
}
}
// WithResponseReadTimeout is a request option that will wrap the body in a timeout read closer.
// This will allow for per read timeouts. If a timeout occurred, we will return the
// ErrCodeResponseTimeout.
//
// svc.PutObjectWithContext(ctx, params, request.WithTimeoutReadCloser(30 * time.Second)
func WithResponseReadTimeout(duration time.Duration) Option {
return func(r *Request) {
var timeoutHandler = NamedHandler{
HandlerResponseTimeout,
func(req *Request) {
req.HTTPResponse.Body = &timeoutReadCloser{
reader: req.HTTPResponse.Body,
duration: duration,
}
}}
// remove the handler so we are not stomping over any new durations.
r.Handlers.Send.RemoveByName(HandlerResponseTimeout)
r.Handlers.Send.PushBackNamed(timeoutHandler)
r.Handlers.Unmarshal.PushBack(adaptToResponseTimeoutError)
r.Handlers.UnmarshalError.PushBack(adaptToResponseTimeoutError)
}
}

View file

@ -0,0 +1,76 @@
package request_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
)
func BenchmarkTimeoutReadCloser(b *testing.B) {
resp := `
{
"Bar": "qux"
}
`
handlers := request.Handlers{}
handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewBuffer([]byte(resp))),
}
})
handlers.Sign.PushBackNamed(v4.SignRequestHandler)
handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
op := &request.Operation{
Name: "op",
HTTPMethod: "POST",
HTTPPath: "/",
}
meta := metadata.ClientInfo{
ServiceName: "fooService",
SigningName: "foo",
SigningRegion: "foo",
Endpoint: "localhost",
APIVersion: "2001-01-01",
JSONVersion: "1.1",
TargetPrefix: "Foo",
}
req := request.New(
*unit.Session.Config,
meta,
handlers,
client.DefaultRetryer{NumMaxRetries: 5},
op,
&struct {
Foo *string
}{},
&struct {
Bar *string
}{},
)
req.ApplyOptions(request.WithResponseReadTimeout(15 * time.Second))
for i := 0; i < b.N; i++ {
err := req.Send()
if err != nil {
b.Errorf("Expected no error, but received %v", err)
}
}
}

View file

@ -0,0 +1,118 @@
package request
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type testReader struct {
duration time.Duration
count int
}
func (r *testReader) Read(b []byte) (int, error) {
if r.count > 0 {
r.count--
return len(b), nil
}
time.Sleep(r.duration)
return 0, io.EOF
}
func (r *testReader) Close() error {
return nil
}
func TestTimeoutReadCloser(t *testing.T) {
reader := timeoutReadCloser{
reader: &testReader{
duration: time.Second,
count: 5,
},
duration: time.Millisecond,
}
b := make([]byte, 100)
_, err := reader.Read(b)
if err != nil {
t.Log(err)
}
}
func TestTimeoutReadCloserSameDuration(t *testing.T) {
reader := timeoutReadCloser{
reader: &testReader{
duration: time.Millisecond,
count: 5,
},
duration: time.Millisecond,
}
b := make([]byte, 100)
_, err := reader.Read(b)
if err != nil {
t.Log(err)
}
}
func TestWithResponseReadTimeout(t *testing.T) {
r := Request{
HTTPResponse: &http.Response{
Body: ioutil.NopCloser(bytes.NewReader(nil)),
},
}
r.ApplyOptions(WithResponseReadTimeout(time.Second))
err := r.Send()
if err != nil {
t.Error(err)
}
v, ok := r.HTTPResponse.Body.(*timeoutReadCloser)
if !ok {
t.Error("Expected the body to be a timeoutReadCloser")
}
if v.duration != time.Second {
t.Errorf("Expected %v, but receive %v\n", time.Second, v.duration)
}
}
func TestAdaptToResponseTimeout(t *testing.T) {
testCases := []struct {
childErr error
r Request
expectedRootCode string
}{
{
childErr: awserr.New(ErrCodeResponseTimeout, "timeout!", nil),
r: Request{
Error: awserr.New("ErrTest", "FooBar", awserr.New(ErrCodeResponseTimeout, "timeout!", nil)),
},
expectedRootCode: ErrCodeResponseTimeout,
},
{
childErr: awserr.New(ErrCodeResponseTimeout+"1", "timeout!", nil),
r: Request{
Error: awserr.New("ErrTest", "FooBar", awserr.New(ErrCodeResponseTimeout+"1", "timeout!", nil)),
},
expectedRootCode: "ErrTest",
},
{
r: Request{
Error: awserr.New("ErrTest", "FooBar", nil),
},
expectedRootCode: "ErrTest",
},
}
for i, c := range testCases {
adaptToResponseTimeoutError(&c.r)
if aerr, ok := c.r.Error.(awserr.Error); !ok {
t.Errorf("Case %d: Expected 'awserr', but received %v", i+1, c.r.Error)
} else if aerr.Code() != c.expectedRootCode {
t.Errorf("Case %d: Expected %q, but received %s", i+1, c.expectedRootCode, aerr.Code())
}
}
}

287
vendor/github.com/aws/aws-sdk-go/aws/request/waiter.go generated vendored Normal file
View file

@ -0,0 +1,287 @@
package request
import (
"fmt"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
// WaiterResourceNotReadyErrorCode is the error code returned by a waiter when
// the waiter's max attempts have been exhausted.
const WaiterResourceNotReadyErrorCode = "ResourceNotReady"
// A WaiterOption is a function that will update the Waiter value's fields to
// configure the waiter.
type WaiterOption func(*Waiter)
// WithWaiterMaxAttempts returns the maximum number of times the waiter should
// attempt to check the resource for the target state.
func WithWaiterMaxAttempts(max int) WaiterOption {
return func(w *Waiter) {
w.MaxAttempts = max
}
}
// WaiterDelay will return a delay the waiter should pause between attempts to
// check the resource state. The passed in attempt is the number of times the
// Waiter has checked the resource state.
//
// Attempt is the number of attempts the Waiter has made checking the resource
// state.
type WaiterDelay func(attempt int) time.Duration
// ConstantWaiterDelay returns a WaiterDelay that will always return a constant
// delay the waiter should use between attempts. It ignores the number of
// attempts made.
func ConstantWaiterDelay(delay time.Duration) WaiterDelay {
return func(attempt int) time.Duration {
return delay
}
}
// WithWaiterDelay will set the Waiter to use the WaiterDelay passed in.
func WithWaiterDelay(delayer WaiterDelay) WaiterOption {
return func(w *Waiter) {
w.Delay = delayer
}
}
// WithWaiterLogger returns a waiter option to set the logger a waiter
// should use to log warnings and errors to.
func WithWaiterLogger(logger aws.Logger) WaiterOption {
return func(w *Waiter) {
w.Logger = logger
}
}
// WithWaiterRequestOptions returns a waiter option setting the request
// options for each request the waiter makes. Appends to waiter's request
// options already set.
func WithWaiterRequestOptions(opts ...Option) WaiterOption {
return func(w *Waiter) {
w.RequestOptions = append(w.RequestOptions, opts...)
}
}
// A Waiter provides the functionality to perform a blocking call which will
// wait for a resource state to be satisfied by a service.
//
// This type should not be used directly. The API operations provided in the
// service packages prefixed with "WaitUntil" should be used instead.
type Waiter struct {
Name string
Acceptors []WaiterAcceptor
Logger aws.Logger
MaxAttempts int
Delay WaiterDelay
RequestOptions []Option
NewRequest func([]Option) (*Request, error)
}
// ApplyOptions updates the waiter with the list of waiter options provided.
func (w *Waiter) ApplyOptions(opts ...WaiterOption) {
for _, fn := range opts {
fn(w)
}
}
// WaiterState are states the waiter uses based on WaiterAcceptor definitions
// to identify if the resource state the waiter is waiting on has occurred.
type WaiterState int
// String returns the string representation of the waiter state.
func (s WaiterState) String() string {
switch s {
case SuccessWaiterState:
return "success"
case FailureWaiterState:
return "failure"
case RetryWaiterState:
return "retry"
default:
return "unknown waiter state"
}
}
// States the waiter acceptors will use to identify target resource states.
const (
SuccessWaiterState WaiterState = iota // waiter successful
FailureWaiterState // waiter failed
RetryWaiterState // waiter needs to be retried
)
// WaiterMatchMode is the mode that the waiter will use to match the WaiterAcceptor
// definition's Expected attribute.
type WaiterMatchMode int
// Modes the waiter will use when inspecting API response to identify target
// resource states.
const (
PathAllWaiterMatch WaiterMatchMode = iota // match on all paths
PathWaiterMatch // match on specific path
PathAnyWaiterMatch // match on any path
PathListWaiterMatch // match on list of paths
StatusWaiterMatch // match on status code
ErrorWaiterMatch // match on error
)
// String returns the string representation of the waiter match mode.
func (m WaiterMatchMode) String() string {
switch m {
case PathAllWaiterMatch:
return "pathAll"
case PathWaiterMatch:
return "path"
case PathAnyWaiterMatch:
return "pathAny"
case PathListWaiterMatch:
return "pathList"
case StatusWaiterMatch:
return "status"
case ErrorWaiterMatch:
return "error"
default:
return "unknown waiter match mode"
}
}
// WaitWithContext will make requests for the API operation using NewRequest to
// build API requests. The request's response will be compared against the
// Waiter's Acceptors to determine the successful state of the resource the
// waiter is inspecting.
//
// The passed in context must not be nil. If it is nil a panic will occur. The
// Context will be used to cancel the waiter's pending requests and retry delays.
// Use aws.BackgroundContext if no context is available.
//
// The waiter will continue until the target state defined by the Acceptors,
// or the max attempts expires.
//
// Will return the WaiterResourceNotReadyErrorCode error code if the waiter's
// retryer ShouldRetry returns false. This normally will happen when the max
// wait attempts expires.
func (w Waiter) WaitWithContext(ctx aws.Context) error {
for attempt := 1; ; attempt++ {
req, err := w.NewRequest(w.RequestOptions)
if err != nil {
waiterLogf(w.Logger, "unable to create request %v", err)
return err
}
req.Handlers.Build.PushBack(MakeAddToUserAgentFreeFormHandler("Waiter"))
err = req.Send()
// See if any of the acceptors match the request's response, or error
for _, a := range w.Acceptors {
if matched, matchErr := a.match(w.Name, w.Logger, req, err); matched {
return matchErr
}
}
// The Waiter should only check the resource state MaxAttempts times
// This is here instead of in the for loop above to prevent delaying
// unnecessary when the waiter will not retry.
if attempt == w.MaxAttempts {
break
}
// Delay to wait before inspecting the resource again
delay := w.Delay(attempt)
if sleepFn := req.Config.SleepDelay; sleepFn != nil {
// Support SleepDelay for backwards compatibility and testing
sleepFn(delay)
} else if err := aws.SleepWithContext(ctx, delay); err != nil {
return awserr.New(CanceledErrorCode, "waiter context canceled", err)
}
}
return awserr.New(WaiterResourceNotReadyErrorCode, "exceeded wait attempts", nil)
}
// A WaiterAcceptor provides the information needed to wait for an API operation
// to complete.
type WaiterAcceptor struct {
State WaiterState
Matcher WaiterMatchMode
Argument string
Expected interface{}
}
// match returns if the acceptor found a match with the passed in request
// or error. True is returned if the acceptor made a match, error is returned
// if there was an error attempting to perform the match.
func (a *WaiterAcceptor) match(name string, l aws.Logger, req *Request, err error) (bool, error) {
result := false
var vals []interface{}
switch a.Matcher {
case PathAllWaiterMatch, PathWaiterMatch:
// Require all matches to be equal for result to match
vals, _ = awsutil.ValuesAtPath(req.Data, a.Argument)
if len(vals) == 0 {
break
}
result = true
for _, val := range vals {
if !awsutil.DeepEqual(val, a.Expected) {
result = false
break
}
}
case PathAnyWaiterMatch:
// Only a single match needs to equal for the result to match
vals, _ = awsutil.ValuesAtPath(req.Data, a.Argument)
for _, val := range vals {
if awsutil.DeepEqual(val, a.Expected) {
result = true
break
}
}
case PathListWaiterMatch:
// ignored matcher
case StatusWaiterMatch:
s := a.Expected.(int)
result = s == req.HTTPResponse.StatusCode
case ErrorWaiterMatch:
if aerr, ok := err.(awserr.Error); ok {
result = aerr.Code() == a.Expected.(string)
}
default:
waiterLogf(l, "WARNING: Waiter %s encountered unexpected matcher: %s",
name, a.Matcher)
}
if !result {
// If there was no matching result found there is nothing more to do
// for this response, retry the request.
return false, nil
}
switch a.State {
case SuccessWaiterState:
// waiter completed
return true, nil
case FailureWaiterState:
// Waiter failure state triggered
return true, awserr.New(WaiterResourceNotReadyErrorCode,
"failed waiting for successful resource state", err)
case RetryWaiterState:
// clear the error and retry the operation
return false, nil
default:
waiterLogf(l, "WARNING: Waiter %s encountered unexpected state: %s",
name, a.State)
return false, nil
}
}
func waiterLogf(logger aws.Logger, msg string, args ...interface{}) {
if logger != nil {
logger.Log(fmt.Sprintf(msg, args...))
}
}

View file

@ -0,0 +1,636 @@
package request_test
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
type mockClient struct {
*client.Client
}
type MockInput struct{}
type MockOutput struct {
States []*MockState
}
type MockState struct {
State *string
}
func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutput) {
op := &request.Operation{
Name: "Mock",
HTTPMethod: "POST",
HTTPPath: "/",
}
if input == nil {
input = &MockInput{}
}
output := &MockOutput{}
req := c.NewRequest(op, input, output)
req.Data = output
return req, output
}
func BuildNewMockRequest(c *mockClient, in *MockInput) func([]request.Option) (*request.Request, error) {
return func(opts []request.Option) (*request.Request, error) {
req, _ := c.MockRequest(in)
req.ApplyOptions(opts...)
return req, nil
}
}
func TestWaiterPathAll(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathAllWaiterMatch,
Argument: "States[].State",
Expected: "running",
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}
func TestWaiterPath(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathWaiterMatch,
Argument: "States[].State",
Expected: "running",
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}
func TestWaiterFailure(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("stopping")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathAllWaiterMatch,
Argument: "States[].State",
Expected: "running",
},
{
State: request.FailureWaiterState,
Matcher: request.PathAnyWaiterMatch,
Argument: "States[].State",
Expected: "stopping",
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext()).(awserr.Error)
assert.Error(t, err)
assert.Equal(t, request.WaiterResourceNotReadyErrorCode, err.Code())
assert.Equal(t, "failed waiting for successful resource state", err.Message())
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}
func TestWaiterError(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.UnmarshalError.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 1, error case retry
},
{ // Request 2, error case failure
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
reqErrs := make([]error, len(resps))
reqErrs[1] = awserr.New("MockException", "mock exception message", nil)
reqErrs[2] = awserr.New("FailureException", "mock failure exception message", nil)
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Send.PushBack(func(r *request.Request) {
code := 200
if reqNum == 1 {
code = 400
}
r.HTTPResponse = &http.Response{
StatusCode: code,
Status: http.StatusText(code),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) {
// If there was an error unmarshal error will be called instead of unmarshal
// need to increment count here also
if err := reqErrs[reqNum]; err != nil {
r.Error = err
reqNum++
}
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathAllWaiterMatch,
Argument: "States[].State",
Expected: "running",
},
{
State: request.RetryWaiterState,
Matcher: request.ErrorWaiterMatch,
Argument: "",
Expected: "MockException",
},
{
State: request.FailureWaiterState,
Matcher: request.ErrorWaiterMatch,
Argument: "",
Expected: "FailureException",
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext())
if err == nil {
t.Fatalf("expected error, but did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := 3, numBuiltReq; e != a {
t.Errorf("expect %d built requests got %d", e, a)
}
if e, a := 3, reqNum; e != a {
t.Errorf("expect %d reqNum got %d", e, a)
}
}
func TestWaiterStatus(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
reqNum++
})
svc.Handlers.Send.PushBack(func(r *request.Request) {
code := 200
if reqNum == 3 {
code = 404
r.Error = awserr.New("NotFound", "Not Found", nil)
}
r.HTTPResponse = &http.Response{
StatusCode: code,
Status: http.StatusText(code),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
})
w := request.Waiter{
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(0),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Argument: "",
Expected: 404,
},
},
NewRequest: BuildNewMockRequest(svc, &MockInput{}),
}
err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, reqNum)
}
func TestWaiter_ApplyOptions(t *testing.T) {
w := request.Waiter{}
logger := aws.NewDefaultLogger()
w.ApplyOptions(
request.WithWaiterLogger(logger),
request.WithWaiterRequestOptions(request.WithLogLevel(aws.LogDebug)),
request.WithWaiterMaxAttempts(2),
request.WithWaiterDelay(request.ConstantWaiterDelay(5*time.Second)),
)
if e, a := logger, w.Logger; e != a {
t.Errorf("expect logger to be set, and match, was not, %v, %v", e, a)
}
if len(w.RequestOptions) != 1 {
t.Fatalf("expect request options to be set to only a single option, %v", w.RequestOptions)
}
r := request.Request{}
r.ApplyOptions(w.RequestOptions...)
if e, a := aws.LogDebug, r.Config.LogLevel.Value(); e != a {
t.Errorf("expect %v loglevel got %v", e, a)
}
if e, a := 2, w.MaxAttempts; e != a {
t.Errorf("expect %d retryer max attempts, got %d", e, a)
}
if e, a := 5*time.Second, w.Delay(0); e != a {
t.Errorf("expect %d retryer delay, got %d", e, a)
}
}
func TestWaiter_WithContextCanceled(t *testing.T) {
c := awstesting.NewClient()
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
reqCount := 0
w := request.Waiter{
Name: "TestWaiter",
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
},
Logger: aws.NewDefaultLogger(),
NewRequest: func(opts []request.Option) (*request.Request, error) {
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound}
req.Handlers.Clear()
req.Data = struct{}{}
req.Handlers.Send.PushBack(func(r *request.Request) {
if reqCount == 1 {
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
}
reqCount++
})
return req, nil
},
}
err := w.WaitWithContext(ctx)
if err == nil {
t.Fatalf("expect waiter to be canceled.")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := 2, reqCount; e != a {
t.Errorf("expect %d requests, got %d", e, a)
}
}
func TestWaiter_WithContext(t *testing.T) {
c := awstesting.NewClient()
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
reqCount := 0
statuses := []int{http.StatusNotFound, http.StatusOK}
w := request.Waiter{
Name: "TestWaiter",
MaxAttempts: 10,
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
},
Logger: aws.NewDefaultLogger(),
NewRequest: func(opts []request.Option) (*request.Request, error) {
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.HTTPResponse = &http.Response{StatusCode: statuses[reqCount]}
req.Handlers.Clear()
req.Data = struct{}{}
req.Handlers.Send.PushBack(func(r *request.Request) {
if reqCount == 1 {
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
}
reqCount++
})
return req, nil
},
}
err := w.WaitWithContext(ctx)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := 2, reqCount; e != a {
t.Errorf("expect %d requests, got %d", e, a)
}
}
func TestWaiter_AttemptsExpires(t *testing.T) {
c := awstesting.NewClient()
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
reqCount := 0
w := request.Waiter{
Name: "TestWaiter",
MaxAttempts: 2,
Delay: request.ConstantWaiterDelay(1 * time.Millisecond),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
},
Logger: aws.NewDefaultLogger(),
NewRequest: func(opts []request.Option) (*request.Request, error) {
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound}
req.Handlers.Clear()
req.Data = struct{}{}
req.Handlers.Send.PushBack(func(r *request.Request) {
reqCount++
})
return req, nil
},
}
err := w.WaitWithContext(ctx)
if err == nil {
t.Fatalf("expect error did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q error code, got %q", e, a)
}
if e, a := 2, reqCount; e != a {
t.Errorf("expect %d requests, got %d", e, a)
}
}
func TestWaiterNilInput(t *testing.T) {
// Code generation doesn't have a great way to verify the code is correct
// other than being run via unit tests in the SDK. This should be fixed
// So code generation can be validated independently.
client := s3.New(unit.Session)
client.Handlers.Validate.Clear()
client.Handlers.Send.Clear() // mock sending
client.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
}
})
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Config.SleepDelay = func(dur time.Duration) {}
// Ensure waiters do not panic on nil input. It doesn't make sense to
// call a waiter without an input, Validation will
err := client.WaitUntilBucketExists(nil)
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
}
func TestWaiterWithContextNilInput(t *testing.T) {
// Code generation doesn't have a great way to verify the code is correct
// other than being run via unit tests in the SDK. This should be fixed
// So code generation can be validated independently.
client := s3.New(unit.Session)
client.Handlers.Validate.Clear()
client.Handlers.Send.Clear() // mock sending
client.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
}
})
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
// Ensure waiters do not panic on nil input
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
err := client.WaitUntilBucketExistsWithContext(ctx, nil,
request.WithWaiterDelay(request.ConstantWaiterDelay(0)),
request.WithWaiterMaxAttempts(1),
)
if err != nil {
t.Fatalf("expect no error, but got %v", err)
}
}

View file

@ -0,0 +1,319 @@
package session
import (
"bytes"
"crypto/tls"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func createTLSServer(cert, key []byte, done <-chan struct{}) (*httptest.Server, error) {
c, err := tls.X509KeyPair(cert, key)
if err != nil {
return nil, err
}
s := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
s.TLS = &tls.Config{
Certificates: []tls.Certificate{c},
}
s.TLS.BuildNameToCertificate()
s.StartTLS()
go func() {
<-done
s.Close()
}()
return s, nil
}
func setupTestCAFile(b []byte) (string, error) {
bundleFile, err := ioutil.TempFile(os.TempDir(), "aws-sdk-go-session-test")
if err != nil {
return "", err
}
_, err = bundleFile.Write(b)
if err != nil {
return "", err
}
defer bundleFile.Close()
return bundleFile.Name(), nil
}
func TestNewSession_WithCustomCABundle_Env(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
done := make(chan struct{})
server, err := createTLSServer(testTLSBundleCert, testTLSBundleKey, done)
assert.NoError(t, err)
// Write bundle to file
caFilename, err := setupTestCAFile(testTLSBundleCA)
defer func() {
os.Remove(caFilename)
}()
assert.NoError(t, err)
os.Setenv("AWS_CA_BUNDLE", caFilename)
s, err := NewSession(&aws.Config{
HTTPClient: &http.Client{},
Endpoint: aws.String(server.URL),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
})
assert.NoError(t, err)
assert.NotNil(t, s)
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
resp, err := s.Config.HTTPClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestNewSession_WithCustomCABundle_EnvNotExists(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_CA_BUNDLE", "file-not-exists")
s, err := NewSession()
assert.Error(t, err)
assert.Equal(t, "LoadCustomCABundleError", err.(awserr.Error).Code())
assert.Nil(t, s)
}
func TestNewSession_WithCustomCABundle_Option(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
done := make(chan struct{})
server, err := createTLSServer(testTLSBundleCert, testTLSBundleKey, done)
assert.NoError(t, err)
s, err := NewSessionWithOptions(Options{
Config: aws.Config{
HTTPClient: &http.Client{},
Endpoint: aws.String(server.URL),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
},
CustomCABundle: bytes.NewReader(testTLSBundleCA),
})
assert.NoError(t, err)
assert.NotNil(t, s)
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
resp, err := s.Config.HTTPClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestNewSession_WithCustomCABundle_OptionPriority(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
done := make(chan struct{})
server, err := createTLSServer(testTLSBundleCert, testTLSBundleKey, done)
assert.NoError(t, err)
os.Setenv("AWS_CA_BUNDLE", "file-not-exists")
s, err := NewSessionWithOptions(Options{
Config: aws.Config{
HTTPClient: &http.Client{},
Endpoint: aws.String(server.URL),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
},
CustomCABundle: bytes.NewReader(testTLSBundleCA),
})
assert.NoError(t, err)
assert.NotNil(t, s)
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
resp, err := s.Config.HTTPClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
type mockRoundTripper struct{}
func (m *mockRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
return nil, nil
}
func TestNewSession_WithCustomCABundle_UnsupportedTransport(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
s, err := NewSessionWithOptions(Options{
Config: aws.Config{
HTTPClient: &http.Client{
Transport: &mockRoundTripper{},
},
},
CustomCABundle: bytes.NewReader(testTLSBundleCA),
})
assert.Error(t, err)
assert.Equal(t, "LoadCustomCABundleError", err.(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).Message(), "transport unsupported type")
assert.Nil(t, s)
}
func TestNewSession_WithCustomCABundle_TransportSet(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
done := make(chan struct{})
server, err := createTLSServer(testTLSBundleCert, testTLSBundleKey, done)
assert.NoError(t, err)
s, err := NewSessionWithOptions(Options{
Config: aws.Config{
Endpoint: aws.String(server.URL),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
HTTPClient: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).Dial,
TLSHandshakeTimeout: 2 * time.Second,
},
},
},
CustomCABundle: bytes.NewReader(testTLSBundleCA),
})
assert.NoError(t, err)
assert.NotNil(t, s)
req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
resp, err := s.Config.HTTPClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
/* Cert generation steps
# Create the CA key
openssl genrsa -des3 -out ca.key 1024
# Create the CA Cert
openssl req -new -sha256 -x509 -days 3650 \
-subj "/C=GO/ST=Gopher/O=Testing ROOT CA" \
-key ca.key -out ca.crt
# Create config
cat > csr_details.txt <<-EOF
[req]
default_bits = 1024
prompt = no
default_md = sha256
req_extensions = SAN
distinguished_name = dn
[ dn ]
C=GO
ST=Gopher
O=Testing Certificate
OU=Testing IP
[SAN]
subjectAltName = IP:127.0.0.1
EOF
# Create certificate signing request
openssl req -new -sha256 -nodes -newkey rsa:1024 \
-config <( cat csr_details.txt ) \
-keyout ia.key -out ia.csr
# Create a signed certificate
openssl x509 -req -days 3650 \
-CAcreateserial \
-extfile <( cat csr_details.txt ) \
-extensions SAN \
-CA ca.crt -CAkey ca.key -in ia.csr -out ia.crt
# Verify
openssl req -noout -text -in ia.csr
openssl x509 -noout -text -in ia.crt
*/
var (
// ca.crt
testTLSBundleCA = []byte(`-----BEGIN CERTIFICATE-----
MIICiTCCAfKgAwIBAgIJAJ5X1olt05XjMA0GCSqGSIb3DQEBCwUAMDgxCzAJBgNV
BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD
QTAeFw0xNzAzMDkwMDAyMDZaFw0yNzAzMDcwMDAyMDZaMDgxCzAJBgNVBAYTAkdP
MQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBDQTCBnzAN
BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAw/8DN+t9XQR60jx42rsQ2WE2Dx85rb3n
GQxnKZZLNddsT8rDyxJNP18aFalbRbFlyln5fxWxZIblu9Xkm/HRhOpbSimSqo1y
uDx21NVZ1YsOvXpHby71jx3gPrrhSc/t/zikhi++6D/C6m1CiIGuiJ0GBiJxtrub
UBMXT0QtI2ECAwEAAaOBmjCBlzAdBgNVHQ4EFgQU8XG3X/YHBA6T04kdEkq6+4GV
YykwaAYDVR0jBGEwX4AU8XG3X/YHBA6T04kdEkq6+4GVYymhPKQ6MDgxCzAJBgNV
BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD
QYIJAJ5X1olt05XjMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADgYEAeILv
z49+uxmPcfOZzonuOloRcpdvyjiXblYxbzz6ch8GsE7Q886FTZbvwbgLhzdwSVgG
G8WHkodDUsymVepdqAamS3f8PdCUk8xIk9mop8LgaB9Ns0/TssxDvMr3sOD2Grb3
xyWymTWMcj6uCiEBKtnUp4rPiefcvCRYZ17/hLE=
-----END CERTIFICATE-----
`)
// ai.crt
testTLSBundleCert = []byte(`-----BEGIN CERTIFICATE-----
MIICGjCCAYOgAwIBAgIJAIIu+NOoxxM0MA0GCSqGSIb3DQEBBQUAMDgxCzAJBgNV
BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD
QTAeFw0xNzAzMDkwMDAzMTRaFw0yNzAzMDcwMDAzMTRaMFExCzAJBgNVBAYTAkdP
MQ8wDQYDVQQIDAZHb3BoZXIxHDAaBgNVBAoME1Rlc3RpbmcgQ2VydGlmaWNhdGUx
EzARBgNVBAsMClRlc3RpbmcgSVAwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGB
AN1hWHeioo/nASvbrjwCQzXCiWiEzGkw353NxsAB54/NqDL3LXNATtiSJu8kJBrm
Ah12IFLtWLGXjGjjYlHbQWnOR6awveeXnQZukJyRWh7m/Qlt9Ho0CgZE1U+832ac
5GWVldNxW1Lz4I+W9/ehzqe8I80RS6eLEKfUFXGiW+9RAgMBAAGjEzARMA8GA1Ud
EQQIMAaHBH8AAAEwDQYJKoZIhvcNAQEFBQADgYEAdF4WQHfVdPCbgv9sxgJjcR1H
Hgw9rZ47gO1IiIhzglnLXQ6QuemRiHeYFg4kjcYBk1DJguxzDTGnUwhUXOibAB+S
zssmrkdYYvn9aUhjc3XK3tjAoDpsPpeBeTBamuUKDHoH/dNRXxerZ8vu6uPR3Pgs
5v/KCV6IAEcvNyOXMPo=
-----END CERTIFICATE-----
`)
// ai.key
testTLSBundleKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQDdYVh3oqKP5wEr2648AkM1wolohMxpMN+dzcbAAeePzagy9y1z
QE7YkibvJCQa5gIddiBS7Vixl4xo42JR20FpzkemsL3nl50GbpCckVoe5v0JbfR6
NAoGRNVPvN9mnORllZXTcVtS8+CPlvf3oc6nvCPNEUunixCn1BVxolvvUQIDAQAB
AoGBAMISrcirddGrlLZLLrKC1ULS2T0cdkqdQtwHYn4+7S5+/z42vMx1iumHLsSk
rVY7X41OWkX4trFxhvEIrc/O48bo2zw78P7flTxHy14uxXnllU8cLThE29SlUU7j
AVBNxJZMsXMlS/DowwD4CjFe+x4Pu9wZcReF2Z9ntzMpySABAkEA+iWoJCPE2JpS
y78q3HYYgpNY3gF3JqQ0SI/zTNkb3YyEIUffEYq0Y9pK13HjKtdsSuX4osTIhQkS
+UgRp6tCAQJBAOKPYTfQ2FX8ijgUpHZRuEAVaxASAS0UATiLgzXxLvOh/VC2at5x
wjOX6sD65pPz/0D8Qj52Cq6Q1TQ+377SDVECQAIy0od+yPweXxvrUjUd1JlRMjbB
TIrKZqs8mKbUQapw0bh5KTy+O1elU4MRPS3jNtBxtP25PQnuSnxmZcFTgAECQFzg
DiiFcsn9FuRagfkHExMiNJuH5feGxeFaP9WzI144v9GAllrOI6Bm3JNzx2ZLlg4b
20Qju8lIEj6yr6JYFaECQHM1VSojGRKpOl9Ox/R4yYSA9RV5Gyn00/aJNxVYyPD5
i3acL2joQm2kLD/LO8paJ4+iQdRXCOMMIpjxSNjGQjQ=
-----END RSA PRIVATE KEY-----
`)
)

View file

@ -23,7 +23,7 @@ additional config if the AWS_SDK_LOAD_CONFIG environment variable is set.
Alternatively you can explicitly create a Session with shared config enabled.
To do this you can use NewSessionWithOptions to configure how the Session will
be created. Using the NewSessionWithOptions with SharedConfigState set to
SharedConfigEnabled will create the session as if the AWS_SDK_LOAD_CONFIG
SharedConfigEnable will create the session as if the AWS_SDK_LOAD_CONFIG
environment variable was set.
Creating Sessions
@ -45,16 +45,16 @@ region, and profile loaded from the environment and shared config automatically.
Requires the AWS_PROFILE to be set, or "default" is used.
// Create Session
sess, err := session.NewSession()
sess := session.Must(session.NewSession())
// Create a Session with a custom region
sess, err := session.NewSession(&aws.Config{Region: aws.String("us-east-1")})
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String("us-east-1"),
}))
// Create a S3 client instance from a session
sess, err := session.NewSession()
if err != nil {
// Handle Session creation error
}
sess := session.Must(session.NewSession())
svc := s3.New(sess)
Create Session With Option Overrides
@ -67,23 +67,25 @@ Use NewSessionWithOptions when you want to provide the config profile, or
override the shared config state (AWS_SDK_LOAD_CONFIG).
// Equivalent to session.NewSession()
sess, err := session.NewSessionWithOptions(session.Options{})
sess := session.Must(session.NewSessionWithOptions(session.Options{
// Options
}))
// Specify profile to load for the session's config
sess, err := session.NewSessionWithOptions(session.Options{
sess := session.Must(session.NewSessionWithOptions(session.Options{
Profile: "profile_name",
})
}))
// Specify profile for config and region for requests
sess, err := session.NewSessionWithOptions(session.Options{
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String("us-east-1")},
Profile: "profile_name",
})
}))
// Force enable Shared Config support
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: SharedConfigEnable,
})
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
Adding Handlers
@ -93,7 +95,8 @@ handler logs every request and its payload made by a service client:
// Create a session, and add additional handlers for all service
// clients created with the Session to inherit. Adds logging handler.
sess, err := session.NewSession()
sess := session.Must(session.NewSession())
sess.Handlers.Send.PushFront(func(r *request.Request) {
// Log every request made and its payload
logger.Println("Request: %s/%s, Payload: %s",
@ -121,9 +124,8 @@ file (~/.aws/config) and shared credentials file (~/.aws/credentials). Both
files have the same format.
If both config files are present the configuration from both files will be
read. The Session will be created from configuration values from the shared
credentials file (~/.aws/credentials) over those in the shared credentials
file (~/.aws/config).
read. The Session will be created from configuration values from the shared
credentials file (~/.aws/credentials) over those in the shared config file (~/.aws/config).
Credentials are the values the SDK should use for authenticating requests with
AWS Services. They arfrom a configuration file will need to include both
@ -138,15 +140,14 @@ the other two fields are also provided.
Assume Role values allow you to configure the SDK to assume an IAM role using
a set of credentials provided in a config file via the source_profile field.
Both "role_arn" and "source_profile" are required. The SDK does not support
assuming a role with MFA token Via the Session's constructor. You can use the
stscreds.AssumeRoleProvider credentials provider to specify custom
configuration and support for MFA.
Both "role_arn" and "source_profile" are required. The SDK supports assuming
a role with MFA token if the session option AssumeRoleTokenProvider
is set.
role_arn = arn:aws:iam::<account_number>:role/<role_name>
source_profile = profile_with_creds
external_id = 1234
mfa_serial = not supported!
mfa_serial = <serial or mfa arn>
role_session_name = session_name
Region is the region the SDK should use for looking up AWS service endpoints
@ -154,6 +155,37 @@ and signing requests.
region = us-east-1
Assume Role with MFA token
To create a session with support for assuming an IAM role with MFA set the
session option AssumeRoleTokenProvider to a function that will prompt for the
MFA token code when the SDK assumes the role and refreshes the role's credentials.
This allows you to configure the SDK via the shared config to assumea role
with MFA tokens.
In order for the SDK to assume a role with MFA the SharedConfigState
session option must be set to SharedConfigEnable, or AWS_SDK_LOAD_CONFIG
environment variable set.
The shared configuration instructs the SDK to assume an IAM role with MFA
when the mfa_serial configuration field is set in the shared config
(~/.aws/config) or shared credentials (~/.aws/credentials) file.
If mfa_serial is set in the configuration, the SDK will assume the role, and
the AssumeRoleTokenProvider session option is not set an an error will
be returned when creating the session.
sess := session.Must(session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
}))
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess)
To setup assume role outside of a session see the stscrds.AssumeRoleProvider
documentation.
Environment Variables
When a Session is created several environment variables can be set to adjust
@ -218,6 +250,24 @@ $HOME/.aws/config on Linux/Unix based systems, and
AWS_CONFIG_FILE=$HOME/my_shared_config
Path to a custom Credentials Authority (CA) bundle PEM file that the SDK
will use instead of the default system's root CA bundle. Use this only
if you want to replace the CA bundle the SDK uses for TLS requests.
AWS_CA_BUNDLE=$HOME/my_custom_ca_bundle
Enabling this option will attempt to merge the Transport into the SDK's HTTP
client. If the client's Transport is not a http.Transport an error will be
returned. If the Transport's TLS config is set this option will cause the SDK
to overwrite the Transport's TLS config's RootCAs value. If the CA bundle file
contains multiple certificates all of them will be loaded.
The Session option CustomCABundle is also available when creating sessions
to also enable this feature. CustomCABundle session option field has priority
over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
Setting a custom HTTPClient in the aws.Config options will override this setting.
To use this option and custom HTTP client, the HTTP client needs to be provided
when creating the session. Not the service client.
*/
package session

View file

@ -75,6 +75,24 @@ type envConfig struct {
//
// AWS_CONFIG_FILE=$HOME/my_shared_config
SharedConfigFile string
// Sets the path to a custom Credentials Authroity (CA) Bundle PEM file
// that the SDK will use instead of the the system's root CA bundle.
// Only use this if you want to configure the SDK to use a custom set
// of CAs.
//
// Enabling this option will attempt to merge the Transport
// into the SDK's HTTP client. If the client's Transport is
// not a http.Transport an error will be returned. If the
// Transport's TLS config is set this option will cause the
// SDK to overwrite the Transport's TLS config's RootCAs value.
//
// Setting a custom HTTPClient in the aws.Config options will override this setting.
// To use this option and custom HTTP client, the HTTP client needs to be provided
// when creating the session. Not the service client.
//
// AWS_CA_BUNDLE=$HOME/my_custom_ca_bundle
CustomCABundle string
}
var (
@ -150,6 +168,8 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
cfg.SharedCredentialsFile = sharedCredentialsFilename()
cfg.SharedConfigFile = sharedConfigFilename()
cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE")
return cfg
}

View file

@ -0,0 +1,291 @@
package session
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func TestLoadEnvConfig_Creds(t *testing.T) {
env := stashEnv()
defer popEnv(env)
cases := []struct {
Env map[string]string
Val credentials.Value
}{
{
Env: map[string]string{
"AWS_ACCESS_KEY": "AKID",
},
Val: credentials.Value{},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY_ID": "AKID",
},
Val: credentials.Value{},
},
{
Env: map[string]string{
"AWS_SECRET_KEY": "SECRET",
},
Val: credentials.Value{},
},
{
Env: map[string]string{
"AWS_SECRET_ACCESS_KEY": "SECRET",
},
Val: credentials.Value{},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY_ID": "AKID",
"AWS_SECRET_ACCESS_KEY": "SECRET",
},
Val: credentials.Value{
AccessKeyID: "AKID", SecretAccessKey: "SECRET",
ProviderName: "EnvConfigCredentials",
},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY": "AKID",
"AWS_SECRET_KEY": "SECRET",
},
Val: credentials.Value{
AccessKeyID: "AKID", SecretAccessKey: "SECRET",
ProviderName: "EnvConfigCredentials",
},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY": "AKID",
"AWS_SECRET_KEY": "SECRET",
"AWS_SESSION_TOKEN": "TOKEN",
},
Val: credentials.Value{
AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "TOKEN",
ProviderName: "EnvConfigCredentials",
},
},
}
for _, c := range cases {
os.Clearenv()
for k, v := range c.Env {
os.Setenv(k, v)
}
cfg := loadEnvConfig()
assert.Equal(t, c.Val, cfg.Creds)
}
}
func TestLoadEnvConfig(t *testing.T) {
env := stashEnv()
defer popEnv(env)
cases := []struct {
Env map[string]string
Region, Profile string
CustomCABundle string
UseSharedConfigCall bool
}{
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_PROFILE": "profile",
},
Region: "region", Profile: "profile",
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_DEFAULT_REGION": "default_region",
"AWS_PROFILE": "profile",
"AWS_DEFAULT_PROFILE": "default_profile",
},
Region: "region", Profile: "profile",
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_DEFAULT_REGION": "default_region",
"AWS_PROFILE": "profile",
"AWS_DEFAULT_PROFILE": "default_profile",
"AWS_SDK_LOAD_CONFIG": "1",
},
Region: "region", Profile: "profile",
},
{
Env: map[string]string{
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
},
},
{
Env: map[string]string{
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
"AWS_SDK_LOAD_CONFIG": "1",
},
Region: "default_region", Profile: "default_profile",
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_PROFILE": "profile",
},
Region: "region", Profile: "profile",
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_DEFAULT_REGION": "default_region",
"AWS_PROFILE": "profile",
"AWS_DEFAULT_PROFILE": "default_profile",
},
Region: "region", Profile: "profile",
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_REGION": "region",
"AWS_DEFAULT_REGION": "default_region",
"AWS_PROFILE": "profile",
"AWS_DEFAULT_PROFILE": "default_profile",
"AWS_SDK_LOAD_CONFIG": "1",
},
Region: "region", Profile: "profile",
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
},
Region: "default_region", Profile: "default_profile",
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_DEFAULT_REGION": "default_region",
"AWS_DEFAULT_PROFILE": "default_profile",
"AWS_SDK_LOAD_CONFIG": "1",
},
Region: "default_region", Profile: "default_profile",
UseSharedConfigCall: true,
},
{
Env: map[string]string{
"AWS_CA_BUNDLE": "custom_ca_bundle",
},
CustomCABundle: "custom_ca_bundle",
},
{
Env: map[string]string{
"AWS_CA_BUNDLE": "custom_ca_bundle",
},
CustomCABundle: "custom_ca_bundle",
UseSharedConfigCall: true,
},
}
for _, c := range cases {
os.Clearenv()
for k, v := range c.Env {
os.Setenv(k, v)
}
var cfg envConfig
if c.UseSharedConfigCall {
cfg = loadSharedEnvConfig()
} else {
cfg = loadEnvConfig()
}
assert.Equal(t, c.Region, cfg.Region)
assert.Equal(t, c.Profile, cfg.Profile)
assert.Equal(t, c.CustomCABundle, cfg.CustomCABundle)
}
}
func TestSharedCredsFilename(t *testing.T) {
env := stashEnv()
defer popEnv(env)
os.Setenv("USERPROFILE", "profile_dir")
expect := filepath.Join("profile_dir", ".aws", "credentials")
name := sharedCredentialsFilename()
assert.Equal(t, expect, name)
os.Setenv("HOME", "home_dir")
expect = filepath.Join("home_dir", ".aws", "credentials")
name = sharedCredentialsFilename()
assert.Equal(t, expect, name)
expect = filepath.Join("path/to/credentials/file")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", expect)
name = sharedCredentialsFilename()
assert.Equal(t, expect, name)
}
func TestSharedConfigFilename(t *testing.T) {
env := stashEnv()
defer popEnv(env)
os.Setenv("USERPROFILE", "profile_dir")
expect := filepath.Join("profile_dir", ".aws", "config")
name := sharedConfigFilename()
assert.Equal(t, expect, name)
os.Setenv("HOME", "home_dir")
expect = filepath.Join("home_dir", ".aws", "config")
name = sharedConfigFilename()
assert.Equal(t, expect, name)
expect = filepath.Join("path/to/config/file")
os.Setenv("AWS_CONFIG_FILE", expect)
name = sharedConfigFilename()
assert.Equal(t, expect, name)
}
func TestSetEnvValue(t *testing.T) {
env := stashEnv()
defer popEnv(env)
os.Setenv("empty_key", "")
os.Setenv("second_key", "2")
os.Setenv("third_key", "3")
var dst string
setFromEnvVal(&dst, []string{
"empty_key", "first_key", "second_key", "third_key",
})
assert.Equal(t, "2", dst)
}
func stashEnv() []string {
env := os.Environ()
os.Clearenv()
return env
}
func popEnv(env []string) {
os.Clearenv()
for _, e := range env {
p := strings.SplitN(e, "=", 2)
os.Setenv(p[0], p[1])
}
}

View file

@ -1,7 +1,13 @@
package session
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@ -52,7 +58,7 @@ func New(cfgs ...*aws.Config) *Session {
envCfg := loadEnvConfig()
if envCfg.EnableSharedConfig {
s, err := newSession(envCfg, cfgs...)
s, err := newSession(Options{}, envCfg, cfgs...)
if err != nil {
// Old session.New expected all errors to be discovered when
// a request is made, and would report the errors then. This
@ -73,7 +79,7 @@ func New(cfgs ...*aws.Config) *Session {
return s
}
return oldNewSession(cfgs...)
return deprecatedNewSession(cfgs...)
}
// NewSession returns a new Session created from SDK defaults, config files,
@ -92,9 +98,10 @@ func New(cfgs ...*aws.Config) *Session {
// control through code how the Session will be created. Such as specifying the
// config profile, and controlling if shared config is enabled or not.
func NewSession(cfgs ...*aws.Config) (*Session, error) {
envCfg := loadEnvConfig()
opts := Options{}
opts.Config.MergeIn(cfgs...)
return newSession(envCfg, cfgs...)
return NewSessionWithOptions(opts)
}
// SharedConfigState provides the ability to optionally override the state
@ -147,6 +154,41 @@ type Options struct {
// will allow you to override the AWS_SDK_LOAD_CONFIG environment variable
// and enable or disable the shared config functionality.
SharedConfigState SharedConfigState
// When the SDK's shared config is configured to assume a role with MFA
// this option is required in order to provide the mechanism that will
// retrieve the MFA token. There is no default value for this field. If
// it is not set an error will be returned when creating the session.
//
// This token provider will be called when ever the assumed role's
// credentials need to be refreshed. Within the context of service clients
// all sharing the same session the SDK will ensure calls to the token
// provider are atomic. When sharing a token provider across multiple
// sessions additional synchronization logic is needed to ensure the
// token providers do not introduce race conditions. It is recommend to
// share the session where possible.
//
// stscreds.StdinTokenProvider is a basic implementation that will prompt
// from stdin for the MFA token code.
//
// This field is only used if the shared configuration is enabled, and
// the config enables assume role wit MFA via the mfa_serial field.
AssumeRoleTokenProvider func() (string, error)
// Reader for a custom Credentials Authority (CA) bundle in PEM format that
// the SDK will use instead of the default system's root CA bundle. Use this
// only if you want to replace the CA bundle the SDK uses for TLS requests.
//
// Enabling this option will attempt to merge the Transport into the SDK's HTTP
// client. If the client's Transport is not a http.Transport an error will be
// returned. If the Transport's TLS config is set this option will cause the SDK
// to overwrite the Transport's TLS config's RootCAs value. If the CA
// bundle reader contains multiple certificates all of them will be loaded.
//
// The Session option CustomCABundle is also available when creating sessions
// to also enable this feature. CustomCABundle session option field has priority
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
CustomCABundle io.Reader
}
// NewSessionWithOptions returns a new Session created from SDK defaults, config files,
@ -161,23 +203,23 @@ type Options struct {
// to be built with retrieving credentials with AssumeRole set in the config.
//
// // Equivalent to session.New
// sess, err := session.NewSessionWithOptions(session.Options{})
// sess := session.Must(session.NewSessionWithOptions(session.Options{}))
//
// // Specify profile to load for the session's config
// sess, err := session.NewSessionWithOptions(session.Options{
// sess := session.Must(session.NewSessionWithOptions(session.Options{
// Profile: "profile_name",
// })
// }))
//
// // Specify profile for config and region for requests
// sess, err := session.NewSessionWithOptions(session.Options{
// sess := session.Must(session.NewSessionWithOptions(session.Options{
// Config: aws.Config{Region: aws.String("us-east-1")},
// Profile: "profile_name",
// })
// }))
//
// // Force enable Shared Config support
// sess, err := session.NewSessionWithOptions(session.Options{
// SharedConfigState: SharedConfigEnable,
// })
// sess := session.Must(session.NewSessionWithOptions(session.Options{
// SharedConfigState: session.SharedConfigEnable,
// }))
func NewSessionWithOptions(opts Options) (*Session, error) {
var envCfg envConfig
if opts.SharedConfigState == SharedConfigEnable {
@ -197,7 +239,18 @@ func NewSessionWithOptions(opts Options) (*Session, error) {
envCfg.EnableSharedConfig = true
}
return newSession(envCfg, &opts.Config)
// Only use AWS_CA_BUNDLE if session option is not provided.
if len(envCfg.CustomCABundle) != 0 && opts.CustomCABundle == nil {
f, err := os.Open(envCfg.CustomCABundle)
if err != nil {
return nil, awserr.New("LoadCustomCABundleError",
"failed to open custom CA bundle PEM file", err)
}
defer f.Close()
opts.CustomCABundle = f
}
return newSession(opts, envCfg, &opts.Config)
}
// Must is a helper function to ensure the Session is valid and there was no
@ -215,7 +268,7 @@ func Must(sess *Session, err error) *Session {
return sess
}
func oldNewSession(cfgs ...*aws.Config) *Session {
func deprecatedNewSession(cfgs ...*aws.Config) *Session {
cfg := defaults.Config()
handlers := defaults.Handlers()
@ -242,7 +295,7 @@ func oldNewSession(cfgs ...*aws.Config) *Session {
return s
}
func newSession(envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
cfg := defaults.Config()
handlers := defaults.Handlers()
@ -266,7 +319,9 @@ func newSession(envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
return nil, err
}
mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers)
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
return nil, err
}
s := &Session{
Config: cfg,
@ -275,10 +330,62 @@ func newSession(envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
initHandlers(s)
// Setup HTTP client with custom cert bundle if enabled
if opts.CustomCABundle != nil {
if err := loadCustomCABundle(s, opts.CustomCABundle); err != nil {
return nil, err
}
}
return s, nil
}
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers) {
func loadCustomCABundle(s *Session, bundle io.Reader) error {
var t *http.Transport
switch v := s.Config.HTTPClient.Transport.(type) {
case *http.Transport:
t = v
default:
if s.Config.HTTPClient.Transport != nil {
return awserr.New("LoadCustomCABundleError",
"unable to load custom CA bundle, HTTPClient's transport unsupported type", nil)
}
}
if t == nil {
t = &http.Transport{}
}
p, err := loadCertPool(bundle)
if err != nil {
return err
}
if t.TLSClientConfig == nil {
t.TLSClientConfig = &tls.Config{}
}
t.TLSClientConfig.RootCAs = p
s.Config.HTTPClient.Transport = t
return nil
}
func loadCertPool(r io.Reader) (*x509.CertPool, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, awserr.New("LoadCustomCABundleError",
"failed to read custom CA bundle PEM file", err)
}
p := x509.NewCertPool()
if !p.AppendCertsFromPEM(b) {
return nil, awserr.New("LoadCustomCABundleError",
"failed to load custom CA bundle PEM file", err)
}
return p, nil
}
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error {
// Merge in user provided configuration
cfg.MergeIn(userCfg)
@ -302,6 +409,11 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.AssumeRoleSource.Creds,
)
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return AssumeRoleTokenProviderNotSetError{}
}
cfg.Credentials = stscreds.NewCredentials(
&Session{
Config: &cfgCp,
@ -311,11 +423,16 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName
// Assume role with external ID
if len(sharedCfg.AssumeRole.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID)
}
// MFA not supported
// Assume role with MFA
if len(sharedCfg.AssumeRole.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
)
} else if len(sharedCfg.Creds.AccessKeyID) > 0 {
@ -336,6 +453,33 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
})
}
}
return nil
}
// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the
// MFAToken option is not set when shared config is configured load assume a
// role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {

View file

@ -0,0 +1,404 @@
package session
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestNewDefaultSession(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
s := New(&aws.Config{Region: aws.String("region")})
assert.Equal(t, "region", *s.Config.Region)
assert.Equal(t, http.DefaultClient, s.Config.HTTPClient)
assert.NotNil(t, s.Config.Logger)
assert.Equal(t, aws.LogOff, *s.Config.LogLevel)
}
func TestNew_WithCustomCreds(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
customCreds := credentials.NewStaticCredentials("AKID", "SECRET", "TOKEN")
s := New(&aws.Config{Credentials: customCreds})
assert.Equal(t, customCreds, s.Config.Credentials)
}
type mockLogger struct {
*bytes.Buffer
}
func (w mockLogger) Log(args ...interface{}) {
fmt.Fprintln(w, args...)
}
func TestNew_WithSessionLoadError(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_CONFIG_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_invalid_source_profile")
logger := bytes.Buffer{}
s := New(&aws.Config{Logger: &mockLogger{&logger}})
assert.NotNil(t, s)
svc := s3.New(s)
_, err := svc.ListBuckets(&s3.ListBucketsInput{})
assert.Error(t, err)
assert.Contains(t, logger.String(), "ERROR: failed to create session with AWS_SDK_LOAD_CONFIG enabled")
assert.Contains(t, err.Error(), SharedConfigAssumeRoleError{
RoleARN: "assume_role_invalid_source_profile_role_arn",
}.Error())
}
func TestSessionCopy(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_REGION", "orig_region")
s := Session{
Config: defaults.Config(),
Handlers: defaults.Handlers(),
}
newSess := s.Copy(&aws.Config{Region: aws.String("new_region")})
assert.Equal(t, "orig_region", *s.Config.Region)
assert.Equal(t, "new_region", *newSess.Config.Region)
}
func TestSessionClientConfig(t *testing.T) {
s, err := NewSession(&aws.Config{Region: aws.String("orig_region")})
assert.NoError(t, err)
cfg := s.ClientConfig("s3", &aws.Config{Region: aws.String("us-west-2")})
assert.Equal(t, "https://s3-us-west-2.amazonaws.com", cfg.Endpoint)
assert.Equal(t, "us-west-2", cfg.SigningRegion)
assert.Equal(t, "us-west-2", *cfg.Config.Region)
}
func TestNewSession_NoCredentials(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
s, err := NewSession()
assert.NoError(t, err)
assert.NotNil(t, s.Config.Credentials)
assert.NotEqual(t, credentials.AnonymousCredentials, s.Config.Credentials)
}
func TestNewSessionWithOptions_OverrideProfile(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "other_profile")
s, err := NewSessionWithOptions(Options{
Profile: "full_profile",
})
assert.NoError(t, err)
assert.Equal(t, "full_profile_region", *s.Config.Region)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "full_profile_akid", creds.AccessKeyID)
assert.Equal(t, "full_profile_secret", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestNewSessionWithOptions_OverrideSharedConfigEnable(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "0")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "full_profile")
s, err := NewSessionWithOptions(Options{
SharedConfigState: SharedConfigEnable,
})
assert.NoError(t, err)
assert.Equal(t, "full_profile_region", *s.Config.Region)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "full_profile_akid", creds.AccessKeyID)
assert.Equal(t, "full_profile_secret", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestNewSessionWithOptions_OverrideSharedConfigDisable(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "full_profile")
s, err := NewSessionWithOptions(Options{
SharedConfigState: SharedConfigDisable,
})
assert.NoError(t, err)
assert.Empty(t, *s.Config.Region)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "full_profile_akid", creds.AccessKeyID)
assert.Equal(t, "full_profile_secret", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestNewSessionWithOptions_Overrides(t *testing.T) {
cases := []struct {
InEnvs map[string]string
InProfile string
OutRegion string
OutCreds credentials.Value
}{
{
InEnvs: map[string]string{
"AWS_SDK_LOAD_CONFIG": "0",
"AWS_SHARED_CREDENTIALS_FILE": testConfigFilename,
"AWS_PROFILE": "other_profile",
},
InProfile: "full_profile",
OutRegion: "full_profile_region",
OutCreds: credentials.Value{
AccessKeyID: "full_profile_akid",
SecretAccessKey: "full_profile_secret",
ProviderName: "SharedConfigCredentials",
},
},
{
InEnvs: map[string]string{
"AWS_SDK_LOAD_CONFIG": "0",
"AWS_SHARED_CREDENTIALS_FILE": testConfigFilename,
"AWS_REGION": "env_region",
"AWS_ACCESS_KEY": "env_akid",
"AWS_SECRET_ACCESS_KEY": "env_secret",
"AWS_PROFILE": "other_profile",
},
InProfile: "full_profile",
OutRegion: "env_region",
OutCreds: credentials.Value{
AccessKeyID: "env_akid",
SecretAccessKey: "env_secret",
ProviderName: "EnvConfigCredentials",
},
},
{
InEnvs: map[string]string{
"AWS_SDK_LOAD_CONFIG": "0",
"AWS_SHARED_CREDENTIALS_FILE": testConfigFilename,
"AWS_CONFIG_FILE": testConfigOtherFilename,
"AWS_PROFILE": "shared_profile",
},
InProfile: "config_file_load_order",
OutRegion: "shared_config_region",
OutCreds: credentials.Value{
AccessKeyID: "shared_config_akid",
SecretAccessKey: "shared_config_secret",
ProviderName: "SharedConfigCredentials",
},
},
}
for _, c := range cases {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
for k, v := range c.InEnvs {
os.Setenv(k, v)
}
s, err := NewSessionWithOptions(Options{
Profile: c.InProfile,
SharedConfigState: SharedConfigEnable,
})
assert.NoError(t, err)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, c.OutRegion, *s.Config.Region)
assert.Equal(t, c.OutCreds.AccessKeyID, creds.AccessKeyID)
assert.Equal(t, c.OutCreds.SecretAccessKey, creds.SecretAccessKey)
assert.Equal(t, c.OutCreds.SessionToken, creds.SessionToken)
assert.Contains(t, creds.ProviderName, c.OutCreds.ProviderName)
}
}
const assumeRoleRespMsg = `
<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleResult>
<AssumedRoleUser>
<Arn>arn:aws:sts::account_id:assumed-role/role/session_name</Arn>
<AssumedRoleId>AKID:session_name</AssumedRoleId>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>AKID</AccessKeyId>
<SecretAccessKey>SECRET</SecretAccessKey>
<SessionToken>SESSION_TOKEN</SessionToken>
<Expiration>%s</Expiration>
</Credentials>
</AssumeRoleResult>
<ResponseMetadata>
<RequestId>request-id</RequestId>
</ResponseMetadata>
</AssumeRoleResponse>
`
func TestSesisonAssumeRole(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_REGION", "us-east-1")
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_w_creds")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))
}))
s, err := NewSession(&aws.Config{Endpoint: aws.String(server.URL), DisableSSL: aws.Bool(true)})
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Equal(t, "SESSION_TOKEN", creds.SessionToken)
assert.Contains(t, creds.ProviderName, "AssumeRoleProvider")
}
func TestSessionAssumeRole_WithMFA(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_REGION", "us-east-1")
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_w_creds")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.FormValue("SerialNumber"), "0123456789")
assert.Equal(t, r.FormValue("TokenCode"), "tokencode")
w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))
}))
customProviderCalled := false
sess, err := NewSessionWithOptions(Options{
Profile: "assume_role_w_mfa",
Config: aws.Config{
Region: aws.String("us-east-1"),
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
},
SharedConfigState: SharedConfigEnable,
AssumeRoleTokenProvider: func() (string, error) {
customProviderCalled = true
return "tokencode", nil
},
})
assert.NoError(t, err)
creds, err := sess.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, customProviderCalled)
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Equal(t, "SESSION_TOKEN", creds.SessionToken)
assert.Contains(t, creds.ProviderName, "AssumeRoleProvider")
}
func TestSessionAssumeRole_WithMFA_NoTokenProvider(t *testing.T) {
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_REGION", "us-east-1")
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_w_creds")
_, err := NewSessionWithOptions(Options{
Profile: "assume_role_w_mfa",
SharedConfigState: SharedConfigEnable,
})
assert.Equal(t, err, AssumeRoleTokenProviderNotSetError{})
}
func TestSessionAssumeRole_DisableSharedConfig(t *testing.T) {
// Backwards compatibility with Shared config disabled
// assume role should not be built into the config.
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "0")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_w_creds")
s, err := NewSession()
assert.NoError(t, err)
creds, err := s.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, "assume_role_w_creds_akid", creds.AccessKeyID)
assert.Equal(t, "assume_role_w_creds_secret", creds.SecretAccessKey)
assert.Contains(t, creds.ProviderName, "SharedConfigCredentials")
}
func TestSessionAssumeRole_InvalidSourceProfile(t *testing.T) {
// Backwards compatibility with Shared config disabled
// assume role should not be built into the config.
oldEnv := initSessionTestEnv()
defer popEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
os.Setenv("AWS_PROFILE", "assume_role_invalid_source_profile")
s, err := NewSession()
assert.Error(t, err)
assert.Contains(t, err.Error(), "SharedConfigAssumeRoleError: failed to load assume role")
assert.Nil(t, s)
}
func initSessionTestEnv() (oldEnv []string) {
oldEnv = stashEnv()
os.Setenv("AWS_CONFIG_FILE", "file_not_exists")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists")
return oldEnv
}

View file

@ -0,0 +1,274 @@
package session
import (
"fmt"
"path/filepath"
"testing"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/go-ini/ini"
"github.com/stretchr/testify/assert"
)
var (
testConfigFilename = filepath.Join("testdata", "shared_config")
testConfigOtherFilename = filepath.Join("testdata", "shared_config_other")
)
func TestLoadSharedConfig(t *testing.T) {
cases := []struct {
Filenames []string
Profile string
Expected sharedConfig
Err error
}{
{
Filenames: []string{"file_not_exists"},
Profile: "default",
},
{
Filenames: []string{testConfigFilename},
Expected: sharedConfig{
Region: "default_region",
},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "config_file_load_order",
Expected: sharedConfig{
Region: "shared_config_region",
Creds: credentials.Value{
AccessKeyID: "shared_config_akid",
SecretAccessKey: "shared_config_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
{
Filenames: []string{testConfigFilename, testConfigOtherFilename},
Profile: "config_file_load_order",
Expected: sharedConfig{
Region: "shared_config_other_region",
Creds: credentials.Value{
AccessKeyID: "shared_config_other_akid",
SecretAccessKey: "shared_config_other_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigOtherFilename),
},
},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "assume_role",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_role_arn",
SourceProfile: "complete_creds",
},
AssumeRoleSource: &sharedConfig{
Creds: credentials.Value{
AccessKeyID: "complete_creds_akid",
SecretAccessKey: "complete_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "assume_role_invalid_source_profile",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_invalid_source_profile_role_arn",
SourceProfile: "profile_not_exists",
},
},
Err: SharedConfigAssumeRoleError{RoleARN: "assume_role_invalid_source_profile_role_arn"},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "assume_role_w_creds",
Expected: sharedConfig{
Creds: credentials.Value{
AccessKeyID: "assume_role_w_creds_akid",
SecretAccessKey: "assume_role_w_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_w_creds_role_arn",
SourceProfile: "assume_role_w_creds",
ExternalID: "1234",
RoleSessionName: "assume_role_w_creds_session_name",
},
AssumeRoleSource: &sharedConfig{
Creds: credentials.Value{
AccessKeyID: "assume_role_w_creds_akid",
SecretAccessKey: "assume_role_w_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
},
{
Filenames: []string{testConfigOtherFilename, testConfigFilename},
Profile: "assume_role_wo_creds",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_wo_creds_role_arn",
SourceProfile: "assume_role_wo_creds",
},
},
Err: SharedConfigAssumeRoleError{RoleARN: "assume_role_wo_creds_role_arn"},
},
{
Filenames: []string{filepath.Join("testdata", "shared_config_invalid_ini")},
Profile: "profile_name",
Err: SharedConfigLoadError{Filename: filepath.Join("testdata", "shared_config_invalid_ini")},
},
}
for i, c := range cases {
cfg, err := loadSharedConfig(c.Profile, c.Filenames)
if c.Err != nil {
assert.Contains(t, err.Error(), c.Err.Error(), "expected error, %d", i)
continue
}
assert.NoError(t, err, "unexpected error, %d", i)
assert.Equal(t, c.Expected, cfg, "not equal, %d", i)
}
}
func TestLoadSharedConfigFromFile(t *testing.T) {
filename := testConfigFilename
f, err := ini.Load(filename)
if err != nil {
t.Fatalf("failed to load test config file, %s, %v", filename, err)
}
iniFile := sharedConfigFile{IniData: f, Filename: filename}
cases := []struct {
Profile string
Expected sharedConfig
Err error
}{
{
Profile: "default",
Expected: sharedConfig{Region: "default_region"},
},
{
Profile: "alt_profile_name",
Expected: sharedConfig{Region: "alt_profile_name_region"},
},
{
Profile: "short_profile_name_first",
Expected: sharedConfig{Region: "short_profile_name_first_short"},
},
{
Profile: "partial_creds",
Expected: sharedConfig{},
},
{
Profile: "complete_creds",
Expected: sharedConfig{
Creds: credentials.Value{
AccessKeyID: "complete_creds_akid",
SecretAccessKey: "complete_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
{
Profile: "complete_creds_with_token",
Expected: sharedConfig{
Creds: credentials.Value{
AccessKeyID: "complete_creds_with_token_akid",
SecretAccessKey: "complete_creds_with_token_secret",
SessionToken: "complete_creds_with_token_token",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
},
},
{
Profile: "full_profile",
Expected: sharedConfig{
Creds: credentials.Value{
AccessKeyID: "full_profile_akid",
SecretAccessKey: "full_profile_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
},
Region: "full_profile_region",
},
},
{
Profile: "partial_assume_role",
Expected: sharedConfig{},
},
{
Profile: "assume_role",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_role_arn",
SourceProfile: "complete_creds",
},
},
},
{
Profile: "assume_role_w_mfa",
Expected: sharedConfig{
AssumeRole: assumeRoleConfig{
RoleARN: "assume_role_role_arn",
SourceProfile: "complete_creds",
MFASerial: "0123456789",
},
},
},
{
Profile: "does_not_exists",
Err: SharedConfigProfileNotExistsError{Profile: "does_not_exists"},
},
}
for i, c := range cases {
cfg := sharedConfig{}
err := cfg.setFromIniFile(c.Profile, iniFile)
if c.Err != nil {
assert.Contains(t, err.Error(), c.Err.Error(), "expected error, %d", i)
continue
}
assert.NoError(t, err, "unexpected error, %d", i)
assert.Equal(t, c.Expected, cfg, "not equal, %d", i)
}
}
func TestLoadSharedConfigIniFiles(t *testing.T) {
cases := []struct {
Filenames []string
Expected []sharedConfigFile
}{
{
Filenames: []string{"not_exists", testConfigFilename},
Expected: []sharedConfigFile{
{Filename: testConfigFilename},
},
},
{
Filenames: []string{testConfigFilename, testConfigOtherFilename},
Expected: []sharedConfigFile{
{Filename: testConfigFilename},
{Filename: testConfigOtherFilename},
},
},
}
for i, c := range cases {
files, err := loadSharedConfigIniFiles(c.Filenames)
assert.NoError(t, err, "unexpected error, %d", i)
assert.Equal(t, len(c.Expected), len(files), "expected num files, %d", i)
for i, expectedFile := range c.Expected {
assert.Equal(t, expectedFile.Filename, files[i].Filename)
}
}
}

View file

@ -0,0 +1,65 @@
[default]
s3 =
unsupported_key=123
other_unsupported=abc
region = default_region
[profile alt_profile_name]
region = alt_profile_name_region
[short_profile_name_first]
region = short_profile_name_first_short
[profile short_profile_name_first]
region = short_profile_name_first_alt
[partial_creds]
aws_access_key_id = partial_creds_akid
[complete_creds]
aws_access_key_id = complete_creds_akid
aws_secret_access_key = complete_creds_secret
[complete_creds_with_token]
aws_access_key_id = complete_creds_with_token_akid
aws_secret_access_key = complete_creds_with_token_secret
aws_session_token = complete_creds_with_token_token
[full_profile]
aws_access_key_id = full_profile_akid
aws_secret_access_key = full_profile_secret
region = full_profile_region
[config_file_load_order]
region = shared_config_region
aws_access_key_id = shared_config_akid
aws_secret_access_key = shared_config_secret
[partial_assume_role]
role_arn = partial_assume_role_role_arn
[assume_role]
role_arn = assume_role_role_arn
source_profile = complete_creds
[assume_role_w_mfa]
role_arn = assume_role_role_arn
source_profile = complete_creds
mfa_serial = 0123456789
[assume_role_invalid_source_profile]
role_arn = assume_role_invalid_source_profile_role_arn
source_profile = profile_not_exists
[assume_role_w_creds]
role_arn = assume_role_w_creds_role_arn
source_profile = assume_role_w_creds
external_id = 1234
role_session_name = assume_role_w_creds_session_name
aws_access_key_id = assume_role_w_creds_akid
aws_secret_access_key = assume_role_w_creds_secret
[assume_role_wo_creds]
role_arn = assume_role_wo_creds_role_arn
source_profile = assume_role_wo_creds

View file

@ -0,0 +1 @@
[profile_nam

View file

@ -0,0 +1,17 @@
[default]
region = default_region
[partial_creds]
aws_access_key_id = AKID
[profile alt_profile_name]
region = alt_profile_name_region
[creds_from_credentials]
aws_access_key_id = creds_from_config_akid
aws_secret_access_key = creds_from_config_secret
[config_file_load_order]
region = shared_config_other_region
aws_access_key_id = shared_config_other_akid
aws_secret_access_key = shared_config_other_secret

View file

@ -0,0 +1,67 @@
// +build go1.5
package v4_test
import (
"fmt"
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestStandaloneSign(t *testing.T) {
creds := unit.Session.Config.Credentials
signer := v4.NewSigner(creds)
for _, c := range standaloneSignCases {
host := fmt.Sprintf("https://%s.%s.%s.amazonaws.com",
c.SubDomain, c.Region, c.Service)
req, err := http.NewRequest("GET", host, nil)
assert.NoError(t, err)
// URL.EscapedPath() will be used by the signer to get the
// escaped form of the request's URI path.
req.URL.Path = c.OrigURI
req.URL.RawQuery = c.OrigQuery
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
assert.NoError(t, err)
actual := req.Header.Get("Authorization")
assert.Equal(t, c.ExpSig, actual)
assert.Equal(t, c.OrigURI, req.URL.Path)
assert.Equal(t, c.EscapedURI, req.URL.EscapedPath())
}
}
func TestStandaloneSign_RawPath(t *testing.T) {
creds := unit.Session.Config.Credentials
signer := v4.NewSigner(creds)
for _, c := range standaloneSignCases {
host := fmt.Sprintf("https://%s.%s.%s.amazonaws.com",
c.SubDomain, c.Region, c.Service)
req, err := http.NewRequest("GET", host, nil)
assert.NoError(t, err)
// URL.EscapedPath() will be used by the signer to get the
// escaped form of the request's URI path.
req.URL.Path = c.OrigURI
req.URL.RawPath = c.EscapedURI
req.URL.RawQuery = c.OrigQuery
_, err = signer.Sign(req, nil, c.Service, c.Region, time.Unix(0, 0))
assert.NoError(t, err)
actual := req.Header.Get("Authorization")
assert.Equal(t, c.ExpSig, actual)
assert.Equal(t, c.OrigURI, req.URL.Path)
assert.Equal(t, c.EscapedURI, req.URL.EscapedPath())
}
}

View file

@ -0,0 +1,120 @@
package v4_test
import (
"net/http"
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var standaloneSignCases = []struct {
OrigURI string
OrigQuery string
Region, Service, SubDomain string
ExpSig string
EscapedURI string
}{
{
OrigURI: `/logs-*/_search`,
OrigQuery: `pretty=true`,
Region: "us-west-2", Service: "es", SubDomain: "hostname-clusterkey",
EscapedURI: `/logs-%2A/_search`,
ExpSig: `AWS4-HMAC-SHA256 Credential=AKID/19700101/us-west-2/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=79d0760751907af16f64a537c1242416dacf51204a7dd5284492d15577973b91`,
},
}
func TestPresignHandler(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, err := req.Presign(5 * time.Minute)
assert.NoError(t, err)
expectedHost := "bucket.s3.mock-region.amazonaws.com"
expectedDate := "19700101T000000Z"
expectedHeaders := "content-disposition;host;x-amz-acl"
expectedSig := "2d76a414208c0eac2a23ef9c834db9635ecd5a0fbb447a00ad191f82d854f55b"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
u, _ := url.Parse(urlstr)
urlQ := u.Query()
assert.Equal(t, expectedHost, u.Host)
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
assert.NotContains(t, urlstr, "+") // + encoded as %20
}
func TestPresignRequest(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, headers, err := req.PresignRequest(5 * time.Minute)
assert.NoError(t, err)
expectedHost := "bucket.s3.mock-region.amazonaws.com"
expectedDate := "19700101T000000Z"
expectedHeaders := "content-disposition;host;x-amz-acl;x-amz-content-sha256"
expectedSig := "a5b2b500dfbf2eab5b4f55bec3e3752e04536ea1d5c047aa93bc9f1130a72cd2"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
expectedHeaderMap := http.Header{
"x-amz-acl": []string{"public-read"},
"content-disposition": []string{"a+b c$d"},
"x-amz-content-sha256": []string{"UNSIGNED-PAYLOAD"},
}
u, _ := url.Parse(urlstr)
urlQ := u.Query()
assert.Equal(t, expectedHost, u.Host)
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
assert.Equal(t, expectedHeaderMap, headers)
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
assert.NotContains(t, urlstr, "+") // + encoded as %20
}
func TestStandaloneSign_CustomURIEscape(t *testing.T) {
var expectSig = `AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=6601e883cc6d23871fd6c2a394c5677ea2b8c82b04a6446786d64cd74f520967`
creds := unit.Session.Config.Credentials
signer := v4.NewSigner(creds, func(s *v4.Signer) {
s.DisableURIPathEscaping = true
})
host := "https://subdomain.us-east-1.es.amazonaws.com"
req, err := http.NewRequest("GET", host, nil)
assert.NoError(t, err)
req.URL.Path = `/log-*/_search`
req.URL.Opaque = "//subdomain.us-east-1.es.amazonaws.com/log-%2A/_search"
_, err = signer.Sign(req, nil, "es", "us-east-1", time.Unix(0, 0))
assert.NoError(t, err)
actual := req.Header.Get("Authorization")
assert.Equal(t, expectSig, actual)
}

View file

@ -0,0 +1,57 @@
package v4
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRuleCheckWhitelist(t *testing.T) {
w := whitelist{
mapRule{
"Cache-Control": struct{}{},
},
}
assert.True(t, w.IsValid("Cache-Control"))
assert.False(t, w.IsValid("Cache-"))
}
func TestRuleCheckBlacklist(t *testing.T) {
b := blacklist{
mapRule{
"Cache-Control": struct{}{},
},
}
assert.False(t, b.IsValid("Cache-Control"))
assert.True(t, b.IsValid("Cache-"))
}
func TestRuleCheckPattern(t *testing.T) {
p := patterns{"X-Amz-Meta-"}
assert.True(t, p.IsValid("X-Amz-Meta-"))
assert.True(t, p.IsValid("X-Amz-Meta-Star"))
assert.False(t, p.IsValid("Cache-"))
}
func TestRuleComplexWhitelist(t *testing.T) {
w := rules{
whitelist{
mapRule{
"Cache-Control": struct{}{},
},
},
patterns{"X-Amz-Meta-"},
}
r := rules{
inclusiveRules{patterns{"X-Amz-"}, blacklist{w}},
}
assert.True(t, r.IsValid("X-Amz-Blah"))
assert.False(t, r.IsValid("X-Amz-Meta-"))
assert.False(t, r.IsValid("X-Amz-Meta-Star"))
assert.False(t, r.IsValid("Cache-Control"))
}

View file

@ -0,0 +1,7 @@
package v4
// WithUnsignedPayload will enable and set the UnsignedPayload field to
// true of the signer.
func WithUnsignedPayload(v4 *Signer) {
v4.UnsignedPayload = true
}

View file

@ -194,6 +194,10 @@ type Signer struct {
// This value should only be used for testing. If it is nil the default
// time.Now will be used.
currentTimeFn func() time.Time
// UnsignedPayload will prevent signing of the payload. This will only
// work for services that have support for this.
UnsignedPayload bool
}
// NewSigner returns a Signer pointer configured with the credentials and optional
@ -227,6 +231,7 @@ type signingCtx struct {
isPresign bool
formattedTime string
formattedShortTime string
unsignedPayload bool
bodyDigest string
signedHeaders string
@ -317,6 +322,7 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
ServiceName: service,
Region: region,
DisableURIPathEscaping: v4.DisableURIPathEscaping,
unsignedPayload: v4.UnsignedPayload,
}
for key := range ctx.Query {
@ -409,7 +415,18 @@ var SignRequestHandler = request.NamedHandler{
func SignSDKRequest(req *request.Request) {
signSDKRequestWithCurrTime(req, time.Now)
}
func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time) {
// BuildNamedHandler will build a generic handler for signing.
func BuildNamedHandler(name string, opts ...func(*Signer)) request.NamedHandler {
return request.NamedHandler{
Name: name,
Fn: func(req *request.Request) {
signSDKRequestWithCurrTime(req, time.Now, opts...)
},
}
}
func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time, opts ...func(*Signer)) {
// If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used.
if req.Config.Credentials == credentials.AnonymousCredentials {
@ -441,6 +458,10 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
v4.DisableRequestBodyOverwrite = true
})
for _, opt := range opts {
opt(v4)
}
signingTime := req.Time
if !req.LastSignedAt.IsZero() {
signingTime = req.LastSignedAt
@ -634,14 +655,14 @@ func (ctx *signingCtx) buildSignature() {
func (ctx *signingCtx) buildBodyDigest() {
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if ctx.isPresign && ctx.ServiceName == "s3" {
if ctx.unsignedPayload || (ctx.isPresign && ctx.ServiceName == "s3") {
hash = "UNSIGNED-PAYLOAD"
} else if ctx.Body == nil {
hash = emptyStringSHA256
} else {
hash = hex.EncodeToString(makeSha256Reader(ctx.Body))
}
if ctx.ServiceName == "s3" || ctx.ServiceName == "glacier" {
if ctx.unsignedPayload || ctx.ServiceName == "s3" || ctx.ServiceName == "glacier" {
ctx.Request.Header.Set("X-Amz-Content-Sha256", hash)
}
}

View file

@ -0,0 +1,497 @@
package v4
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestStripExcessHeaders(t *testing.T) {
vals := []string{
"123",
"1 2 3",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
}
expected := []string{
"123",
"1 2 3",
"1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2",
"1 2",
}
newVals := stripExcessSpaces(vals)
for i := 0; i < len(newVals); i++ {
assert.Equal(t, expected[i], newVals[i], "test: %d", i)
}
}
func buildRequest(serviceName, region, body string) (*http.Request, io.ReadSeeker) {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
reader := strings.NewReader(body)
req, _ := http.NewRequest("POST", endpoint, reader)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Add("X-Amz-Target", "prefix.Operation")
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
req.Header.Add("Content-Length", string(len(body)))
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
return req, reader
}
func buildSigner() Signer {
return Signer{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
}
}
func removeWS(text string) string {
text = strings.Replace(text, " ", "", -1)
text = strings.Replace(text, "\n", "", -1)
text = strings.Replace(text, "\t", "", -1)
return text
}
func assertEqual(t *testing.T, expected, given string) {
if removeWS(expected) != removeWS(given) {
t.Errorf("\nExpected: %s\nGiven: %s", expected, given)
}
}
func TestPresignRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
signer := buildSigner()
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Unix(0, 0))
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedSig := "ea7856749041f727690c580569738282e99c79355fe0d8f125d3b5535d2ece83"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
q := req.URL.Query()
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
assert.Empty(t, q.Get("X-Amz-Meta-Other-Header"))
assert.Equal(t, expectedTarget, q.Get("X-Amz-Target"))
}
func TestPresignBodyWithArrayRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
signer := buildSigner()
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Unix(0, 0))
expectedDate := "19700101T000000Z"
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
expectedSig := "fef6002062400bbf526d70f1a6456abc0fb2e213fe1416012737eebd42a62924"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
expectedTarget := "prefix.Operation"
q := req.URL.Query()
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
assert.Empty(t, q.Get("X-Amz-Meta-Other-Header"))
assert.Equal(t, expectedTarget, q.Get("X-Amz-Target"))
}
func TestSignRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
signer := buildSigner()
signer.Sign(req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=ea766cabd2ec977d955a3c2bae1ae54f4515d70752f2207618396f20aa85bd21"
q := req.Header
assert.Equal(t, expectedSig, q.Get("Authorization"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignBodyS3(t *testing.T) {
req, body := buildRequest("s3", "us-east-1", "hello")
signer := buildSigner()
signer.Sign(req, body, "s3", "us-east-1", time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
}
func TestSignBodyGlacier(t *testing.T) {
req, body := buildRequest("glacier", "us-east-1", "hello")
signer := buildSigner()
signer.Sign(req, body, "glacier", "us-east-1", time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
}
func TestPresignEmptyBodyS3(t *testing.T) {
req, body := buildRequest("s3", "us-east-1", "hello")
signer := buildSigner()
signer.Presign(req, body, "s3", "us-east-1", 5*time.Minute, time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "UNSIGNED-PAYLOAD", hash)
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "hello")
req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer := buildSigner()
signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
hash := req.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "PRECOMPUTED", hash)
}
func TestAnonymousCredentials(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{Credentials: credentials.AnonymousCredentials})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
SignSDKRequest(r)
urlQ := r.HTTPRequest.URL.Query()
assert.Empty(t, urlQ.Get("X-Amz-Signature"))
assert.Empty(t, urlQ.Get("X-Amz-Credential"))
assert.Empty(t, urlQ.Get("X-Amz-SignedHeaders"))
assert.Empty(t, urlQ.Get("X-Amz-Date"))
hQ := r.HTTPRequest.Header
assert.Empty(t, hQ.Get("Authorization"))
assert.Empty(t, hQ.Get("X-Amz-Date"))
}
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
SignSDKRequest(r)
sig := r.HTTPRequest.Header.Get("Authorization")
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate one second has passed so that signature's date changes
// when it is resigned.
return time.Now().Add(1 * time.Second)
})
assert.NotEqual(t, sig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
SignSDKRequest(r)
sig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate one second has passed so that signature's date changes
// when it is resigned.
return time.Now().Add(1 * time.Second)
})
assert.NotEqual(t, sig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
}
func TestResignRequestExpiredCreds(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
svc := awstesting.NewClient(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
SignSDKRequest(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
var origSignedHeaders string
for _, p := range strings.Split(querySig, ", ") {
if strings.HasPrefix(p, "SignedHeaders=") {
origSignedHeaders = p[len("SignedHeaders="):]
break
}
}
assert.NotEmpty(t, origSignedHeaders)
assert.NotContains(t, origSignedHeaders, "authorization")
origSignedAt := r.LastSignedAt
creds.Expire()
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate one second has passed so that signature's date changes
// when it is resigned.
return time.Now().Add(1 * time.Second)
})
updatedQuerySig := r.HTTPRequest.Header.Get("Authorization")
assert.NotEqual(t, querySig, updatedQuerySig)
var updatedSignedHeaders string
for _, p := range strings.Split(updatedQuerySig, ", ") {
if strings.HasPrefix(p, "SignedHeaders=") {
updatedSignedHeaders = p[len("SignedHeaders="):]
break
}
}
assert.NotEmpty(t, updatedSignedHeaders)
assert.NotContains(t, updatedQuerySig, "authorization")
assert.NotEqual(t, origSignedAt, r.LastSignedAt)
}
func TestPreResignRequestExpiredCreds(t *testing.T) {
provider := &credentials.StaticProvider{Value: credentials.Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "SESSION",
}}
creds := credentials.NewCredentials(provider)
svc := awstesting.NewClient(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
SignSDKRequest(r)
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
signedHeaders := r.HTTPRequest.URL.Query().Get("X-Amz-SignedHeaders")
assert.NotEmpty(t, signedHeaders)
origSignedAt := r.LastSignedAt
creds.Expire()
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate the request occurred 15 minutes in the past
return time.Now().Add(-48 * time.Hour)
})
assert.NotEqual(t, querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
resignedHeaders := r.HTTPRequest.URL.Query().Get("X-Amz-SignedHeaders")
assert.Equal(t, signedHeaders, resignedHeaders)
assert.NotContains(t, signedHeaders, "x-amz-signedHeaders")
assert.NotEqual(t, origSignedAt, r.LastSignedAt)
}
func TestResignRequestExpiredRequest(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
svc := awstesting.NewClient(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
SignSDKRequest(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
origSignedAt := r.LastSignedAt
signSDKRequestWithCurrTime(r, func() time.Time {
// Simulate the request occurred 15 minutes in the past
return time.Now().Add(15 * time.Minute)
})
assert.NotEqual(t, querySig, r.HTTPRequest.Header.Get("Authorization"))
assert.NotEqual(t, origSignedAt, r.LastSignedAt)
}
func TestSignWithRequestBody(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
signer := NewSigner(creds)
expectBody := []byte("abc123")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
r.Body.Close()
assert.NoError(t, err)
assert.Equal(t, expectBody, b)
w.WriteHeader(http.StatusOK)
}))
req, err := http.NewRequest("POST", server.URL, nil)
_, err = signer.Sign(req, bytes.NewReader(expectBody), "service", "region", time.Now())
assert.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestSignWithRequestBody_Overwrite(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
signer := NewSigner(creds)
var expectBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
r.Body.Close()
assert.NoError(t, err)
assert.Equal(t, len(expectBody), len(b))
w.WriteHeader(http.StatusOK)
}))
req, err := http.NewRequest("GET", server.URL, strings.NewReader("invalid body"))
_, err = signer.Sign(req, nil, "service", "region", time.Now())
req.ContentLength = 0
assert.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestBuildCanonicalRequest(t *testing.T) {
req, body := buildRequest("dynamodb", "us-east-1", "{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
ctx := &signingCtx{
ServiceName: "dynamodb",
Region: "us-east-1",
Request: req,
Body: body,
Query: req.URL.Query(),
Time: time.Now(),
ExpireTime: 5 * time.Second,
}
ctx.buildCanonicalString()
expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=z&Foo=o&Foo=m&Foo=a"
assert.Equal(t, expected, ctx.Request.URL.String())
}
func TestSignWithBody_ReplaceRequestBody(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
req, seekerBody := buildRequest("dynamodb", "us-east-1", "{}")
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
s := NewSigner(creds)
origBody := req.Body
_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if req.Body == origBody {
t.Errorf("expeect request body to not be origBody")
}
if req.Body == nil {
t.Errorf("expect request body to be changed but was nil")
}
}
func TestSignWithBody_NoReplaceRequestBody(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
req, seekerBody := buildRequest("dynamodb", "us-east-1", "{}")
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
s := NewSigner(creds, func(signer *Signer) {
signer.DisableRequestBodyOverwrite = true
})
origBody := req.Body
_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if req.Body != origBody {
t.Errorf("expeect request body to not be chagned")
}
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner()
req, body := buildRequest("dynamodb", "us-east-1", "{}")
for i := 0; i < b.N; i++ {
signer.Presign(req, body, "dynamodb", "us-east-1", 300*time.Second, time.Now())
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner()
req, body := buildRequest("dynamodb", "us-east-1", "{}")
for i := 0; i < b.N; i++ {
signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
}
}
func BenchmarkStripExcessSpaces(b *testing.B) {
vals := []string{
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
`123 321 123 321`,
` 123 321 123 321 `,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
stripExcessSpaces(vals)
}
}

75
vendor/github.com/aws/aws-sdk-go/aws/types_test.go generated vendored Normal file
View file

@ -0,0 +1,75 @@
package aws
import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWriteAtBuffer(t *testing.T) {
b := &WriteAtBuffer{}
n, err := b.WriteAt([]byte{1}, 0)
assert.NoError(t, err)
assert.Equal(t, 1, n)
n, err = b.WriteAt([]byte{1, 1, 1}, 5)
assert.NoError(t, err)
assert.Equal(t, 3, n)
n, err = b.WriteAt([]byte{2}, 1)
assert.NoError(t, err)
assert.Equal(t, 1, n)
n, err = b.WriteAt([]byte{3}, 2)
assert.NoError(t, err)
assert.Equal(t, 1, n)
assert.Equal(t, []byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes())
}
func BenchmarkWriteAtBuffer(b *testing.B) {
buf := &WriteAtBuffer{}
r := rand.New(rand.NewSource(1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
to := r.Intn(10) * 4096
bs := make([]byte, to)
buf.WriteAt(bs, r.Int63n(10)*4096)
}
}
func BenchmarkWriteAtBufferOrderedWrites(b *testing.B) {
// test the performance of a WriteAtBuffer when written in an
// ordered fashion. This is similar to the behavior of the
// s3.Downloader, since downloads the first chunk of the file, then
// the second, and so on.
//
// This test simulates a 150MB file being written in 30 ordered 5MB chunks.
chunk := int64(5e6)
max := chunk * 30
// we'll write the same 5MB chunk every time
tmp := make([]byte, chunk)
for i := 0; i < b.N; i++ {
buf := &WriteAtBuffer{}
for i := int64(0); i < max; i += chunk {
buf.WriteAt(tmp, i)
}
}
}
func BenchmarkWriteAtBufferParallel(b *testing.B) {
buf := &WriteAtBuffer{}
r := rand.New(rand.NewSource(1))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
to := r.Intn(10) * 4096
bs := make([]byte, to)
buf.WriteAt(bs, r.Int63n(10)*4096)
}
})
}

12
vendor/github.com/aws/aws-sdk-go/aws/url.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
// +build go1.8
package aws
import "net/url"
// URLHostname will extract the Hostname without port from the URL value.
//
// Wrapper of net/url#URL.Hostname for backwards Go version compatibility.
func URLHostname(url *url.URL) string {
return url.Hostname()
}

Some files were not shown because too many files have changed in this diff Show more