diff --git a/providers/dns/route53/route53.go b/providers/dns/route53/route53.go index daa81aeb..01293166 100644 --- a/providers/dns/route53/route53.go +++ b/providers/dns/route53/route53.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" @@ -28,15 +29,22 @@ const ( EnvRegion = envNamespace + "REGION" EnvHostedZoneID = envNamespace + "HOSTED_ZONE_ID" EnvMaxRetries = envNamespace + "MAX_RETRIES" + EnvAssumeRoleArn = envNamespace + "ASSUME_ROLE_ARN" EnvTTL = envNamespace + "TTL" EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT" EnvPollingInterval = envNamespace + "POLLING_INTERVAL" - EnvAssumeRoleArn = envNamespace + "ASSUME_ROLE_ARN" ) // Config is used to configure the creation of the DNSProvider. type Config struct { + // Static credential chain. + // These are not set via environment for the time being and are only used if they are explicitly provided. + AccessKeyID string + SecretAccessKey string + SessionToken string + Region string + HostedZoneID string MaxRetries int AssumeRoleArn string @@ -301,10 +309,23 @@ func (d *DNSProvider) getHostedZoneID(fqdn string) (string, error) { } func createSession(config *Config) (*session.Session, error) { + if err := createSessionCheckParams(config); err != nil { + return nil, err + } + retry := customRetryer{} retry.NumMaxRetries = config.MaxRetries - sessionCfg := request.WithRetryer(aws.NewConfig(), retry) + awsConfig := aws.NewConfig() + if config.AccessKeyID != "" && config.SecretAccessKey != "" { + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(config.AccessKeyID, config.SecretAccessKey, config.SessionToken)) + } + + if config.Region != "" { + awsConfig = awsConfig.WithRegion(config.Region) + } + + sessionCfg := request.WithRetryer(awsConfig, retry) sess, err := session.NewSessionWithOptions(session.Options{Config: *sessionCfg}) if err != nil { @@ -320,3 +341,19 @@ func createSession(config *Config) (*session.Session, error) { Credentials: stscreds.NewCredentials(sess, config.AssumeRoleArn), }) } + +func createSessionCheckParams(config *Config) error { + if config == nil { + return errors.New("config is nil") + } + + switch { + case config.SessionToken != "" && config.AccessKeyID == "" && config.SecretAccessKey == "": + return errors.New("SessionToken must be supplied with AccessKeyID and SecretAccessKey") + + case config.AccessKeyID == "" && config.SecretAccessKey != "" || config.AccessKeyID != "" && config.SecretAccessKey == "": + return errors.New("AccessKeyID and SecretAccessKey must be supplied together") + } + + return nil +} diff --git a/providers/dns/route53/route53_test.go b/providers/dns/route53/route53_test.go index 5f115695..f1dd1a29 100644 --- a/providers/dns/route53/route53_test.go +++ b/providers/dns/route53/route53_test.go @@ -177,3 +177,133 @@ func TestDNSProvider_Present(t *testing.T) { err := provider.Present(domain, "", keyAuth) require.NoError(t, err, "Expected Present to return no error") } + +func TestCreateSession(t *testing.T) { + testCases := []struct { + desc string + env map[string]string + config *Config + wantCreds credentials.Value + wantDefaultChain bool + wantRegion string + wantErr string + }{ + { + desc: "config is nil", + wantErr: "config is nil", + }, + { + desc: "session token without access key id or secret access key", + config: &Config{SessionToken: "foo"}, + wantErr: "SessionToken must be supplied with AccessKeyID and SecretAccessKey", + }, + { + desc: "access key id without secret access key", + config: &Config{AccessKeyID: "foo"}, + wantErr: "AccessKeyID and SecretAccessKey must be supplied together", + }, + { + desc: "access key id without secret access key", + config: &Config{SecretAccessKey: "foo"}, + wantErr: "AccessKeyID and SecretAccessKey must be supplied together", + }, + { + desc: "credentials from default chain", + config: &Config{}, + wantDefaultChain: true, + }, + { + desc: "static credentials", + config: &Config{ + AccessKeyID: "one", + SecretAccessKey: "two", + }, + wantCreds: credentials.Value{ + AccessKeyID: "one", + SecretAccessKey: "two", + SessionToken: "", + ProviderName: credentials.StaticProviderName, + }, + }, + { + desc: "static credentials with session token", + config: &Config{ + AccessKeyID: "one", + SecretAccessKey: "two", + SessionToken: "three", + }, + wantCreds: credentials.Value{ + AccessKeyID: "one", + SecretAccessKey: "two", + SessionToken: "three", + ProviderName: credentials.StaticProviderName, + }, + }, + { + desc: "region from env", + config: &Config{}, + env: map[string]string{ + "AWS_REGION": "foo", + }, + wantDefaultChain: true, + wantRegion: "foo", + }, + { + desc: "static region", + config: &Config{ + Region: "one", + }, + env: map[string]string{ + "AWS_REGION": "foo", + }, + wantDefaultChain: true, + wantRegion: "one", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + defer envTest.RestoreEnv() + envTest.ClearEnv() + + envTest.Apply(test.env) + + sess, err := createSession(test.config) + requireErr(t, err, test.wantErr) + + if err != nil { + return + } + + gotCreds, err := sess.Config.Credentials.Get() + + if test.wantDefaultChain { + assert.NotEqual(t, credentials.StaticProviderName, gotCreds.ProviderName) + } else { + require.NoError(t, err) + assert.Equal(t, test.wantCreds, gotCreds) + } + + if test.wantRegion != "" { + assert.Equal(t, test.wantRegion, aws.StringValue(sess.Config.Region)) + } + }) + } +} + +func requireErr(t *testing.T, err error, wantErr string) { + t.Helper() + + switch { + case err != nil && wantErr == "": + // force the assertion error. + require.NoError(t, err) + + case err == nil && wantErr != "": + // force the assertion error. + require.EqualError(t, err, wantErr) + + case err != nil && wantErr != "": + require.EqualError(t, err, wantErr) + } +}