route53: Allow static credentials to be supplied (#1746)

Co-authored-by: Fernandez Ludovic <ldez@users.noreply.github.com>
This commit is contained in:
Chris Marchesi 2023-01-15 14:50:35 -08:00 committed by GitHub
parent 07d957fdc1
commit 0122506c23
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 169 additions and 2 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
@ -28,15 +29,22 @@ const (
EnvRegion = envNamespace + "REGION"
EnvHostedZoneID = envNamespace + "HOSTED_ZONE_ID"
EnvMaxRetries = envNamespace + "MAX_RETRIES"
EnvAssumeRoleArn = envNamespace + "ASSUME_ROLE_ARN"
EnvTTL = envNamespace + "TTL"
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
EnvPollingInterval = envNamespace + "POLLING_INTERVAL"
EnvAssumeRoleArn = envNamespace + "ASSUME_ROLE_ARN"
)
// Config is used to configure the creation of the DNSProvider.
type Config struct {
// Static credential chain.
// These are not set via environment for the time being and are only used if they are explicitly provided.
AccessKeyID string
SecretAccessKey string
SessionToken string
Region string
HostedZoneID string
MaxRetries int
AssumeRoleArn string
@ -301,10 +309,23 @@ func (d *DNSProvider) getHostedZoneID(fqdn string) (string, error) {
}
func createSession(config *Config) (*session.Session, error) {
if err := createSessionCheckParams(config); err != nil {
return nil, err
}
retry := customRetryer{}
retry.NumMaxRetries = config.MaxRetries
sessionCfg := request.WithRetryer(aws.NewConfig(), retry)
awsConfig := aws.NewConfig()
if config.AccessKeyID != "" && config.SecretAccessKey != "" {
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(config.AccessKeyID, config.SecretAccessKey, config.SessionToken))
}
if config.Region != "" {
awsConfig = awsConfig.WithRegion(config.Region)
}
sessionCfg := request.WithRetryer(awsConfig, retry)
sess, err := session.NewSessionWithOptions(session.Options{Config: *sessionCfg})
if err != nil {
@ -320,3 +341,19 @@ func createSession(config *Config) (*session.Session, error) {
Credentials: stscreds.NewCredentials(sess, config.AssumeRoleArn),
})
}
func createSessionCheckParams(config *Config) error {
if config == nil {
return errors.New("config is nil")
}
switch {
case config.SessionToken != "" && config.AccessKeyID == "" && config.SecretAccessKey == "":
return errors.New("SessionToken must be supplied with AccessKeyID and SecretAccessKey")
case config.AccessKeyID == "" && config.SecretAccessKey != "" || config.AccessKeyID != "" && config.SecretAccessKey == "":
return errors.New("AccessKeyID and SecretAccessKey must be supplied together")
}
return nil
}

View file

@ -177,3 +177,133 @@ func TestDNSProvider_Present(t *testing.T) {
err := provider.Present(domain, "", keyAuth)
require.NoError(t, err, "Expected Present to return no error")
}
func TestCreateSession(t *testing.T) {
testCases := []struct {
desc string
env map[string]string
config *Config
wantCreds credentials.Value
wantDefaultChain bool
wantRegion string
wantErr string
}{
{
desc: "config is nil",
wantErr: "config is nil",
},
{
desc: "session token without access key id or secret access key",
config: &Config{SessionToken: "foo"},
wantErr: "SessionToken must be supplied with AccessKeyID and SecretAccessKey",
},
{
desc: "access key id without secret access key",
config: &Config{AccessKeyID: "foo"},
wantErr: "AccessKeyID and SecretAccessKey must be supplied together",
},
{
desc: "access key id without secret access key",
config: &Config{SecretAccessKey: "foo"},
wantErr: "AccessKeyID and SecretAccessKey must be supplied together",
},
{
desc: "credentials from default chain",
config: &Config{},
wantDefaultChain: true,
},
{
desc: "static credentials",
config: &Config{
AccessKeyID: "one",
SecretAccessKey: "two",
},
wantCreds: credentials.Value{
AccessKeyID: "one",
SecretAccessKey: "two",
SessionToken: "",
ProviderName: credentials.StaticProviderName,
},
},
{
desc: "static credentials with session token",
config: &Config{
AccessKeyID: "one",
SecretAccessKey: "two",
SessionToken: "three",
},
wantCreds: credentials.Value{
AccessKeyID: "one",
SecretAccessKey: "two",
SessionToken: "three",
ProviderName: credentials.StaticProviderName,
},
},
{
desc: "region from env",
config: &Config{},
env: map[string]string{
"AWS_REGION": "foo",
},
wantDefaultChain: true,
wantRegion: "foo",
},
{
desc: "static region",
config: &Config{
Region: "one",
},
env: map[string]string{
"AWS_REGION": "foo",
},
wantDefaultChain: true,
wantRegion: "one",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
defer envTest.RestoreEnv()
envTest.ClearEnv()
envTest.Apply(test.env)
sess, err := createSession(test.config)
requireErr(t, err, test.wantErr)
if err != nil {
return
}
gotCreds, err := sess.Config.Credentials.Get()
if test.wantDefaultChain {
assert.NotEqual(t, credentials.StaticProviderName, gotCreds.ProviderName)
} else {
require.NoError(t, err)
assert.Equal(t, test.wantCreds, gotCreds)
}
if test.wantRegion != "" {
assert.Equal(t, test.wantRegion, aws.StringValue(sess.Config.Region))
}
})
}
}
func requireErr(t *testing.T, err error, wantErr string) {
t.Helper()
switch {
case err != nil && wantErr == "":
// force the assertion error.
require.NoError(t, err)
case err == nil && wantErr != "":
// force the assertion error.
require.EqualError(t, err, wantErr)
case err != nil && wantErr != "":
require.EqualError(t, err, wantErr)
}
}