forked from TrueCloudLab/lego
route53: Allow static credentials to be supplied (#1746)
Co-authored-by: Fernandez Ludovic <ldez@users.noreply.github.com>
This commit is contained in:
parent
07d957fdc1
commit
0122506c23
2 changed files with 169 additions and 2 deletions
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
|
@ -28,15 +29,22 @@ const (
|
|||
EnvRegion = envNamespace + "REGION"
|
||||
EnvHostedZoneID = envNamespace + "HOSTED_ZONE_ID"
|
||||
EnvMaxRetries = envNamespace + "MAX_RETRIES"
|
||||
EnvAssumeRoleArn = envNamespace + "ASSUME_ROLE_ARN"
|
||||
|
||||
EnvTTL = envNamespace + "TTL"
|
||||
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
|
||||
EnvPollingInterval = envNamespace + "POLLING_INTERVAL"
|
||||
EnvAssumeRoleArn = envNamespace + "ASSUME_ROLE_ARN"
|
||||
)
|
||||
|
||||
// Config is used to configure the creation of the DNSProvider.
|
||||
type Config struct {
|
||||
// Static credential chain.
|
||||
// These are not set via environment for the time being and are only used if they are explicitly provided.
|
||||
AccessKeyID string
|
||||
SecretAccessKey string
|
||||
SessionToken string
|
||||
Region string
|
||||
|
||||
HostedZoneID string
|
||||
MaxRetries int
|
||||
AssumeRoleArn string
|
||||
|
@ -301,10 +309,23 @@ func (d *DNSProvider) getHostedZoneID(fqdn string) (string, error) {
|
|||
}
|
||||
|
||||
func createSession(config *Config) (*session.Session, error) {
|
||||
if err := createSessionCheckParams(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retry := customRetryer{}
|
||||
retry.NumMaxRetries = config.MaxRetries
|
||||
|
||||
sessionCfg := request.WithRetryer(aws.NewConfig(), retry)
|
||||
awsConfig := aws.NewConfig()
|
||||
if config.AccessKeyID != "" && config.SecretAccessKey != "" {
|
||||
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(config.AccessKeyID, config.SecretAccessKey, config.SessionToken))
|
||||
}
|
||||
|
||||
if config.Region != "" {
|
||||
awsConfig = awsConfig.WithRegion(config.Region)
|
||||
}
|
||||
|
||||
sessionCfg := request.WithRetryer(awsConfig, retry)
|
||||
|
||||
sess, err := session.NewSessionWithOptions(session.Options{Config: *sessionCfg})
|
||||
if err != nil {
|
||||
|
@ -320,3 +341,19 @@ func createSession(config *Config) (*session.Session, error) {
|
|||
Credentials: stscreds.NewCredentials(sess, config.AssumeRoleArn),
|
||||
})
|
||||
}
|
||||
|
||||
func createSessionCheckParams(config *Config) error {
|
||||
if config == nil {
|
||||
return errors.New("config is nil")
|
||||
}
|
||||
|
||||
switch {
|
||||
case config.SessionToken != "" && config.AccessKeyID == "" && config.SecretAccessKey == "":
|
||||
return errors.New("SessionToken must be supplied with AccessKeyID and SecretAccessKey")
|
||||
|
||||
case config.AccessKeyID == "" && config.SecretAccessKey != "" || config.AccessKeyID != "" && config.SecretAccessKey == "":
|
||||
return errors.New("AccessKeyID and SecretAccessKey must be supplied together")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -177,3 +177,133 @@ func TestDNSProvider_Present(t *testing.T) {
|
|||
err := provider.Present(domain, "", keyAuth)
|
||||
require.NoError(t, err, "Expected Present to return no error")
|
||||
}
|
||||
|
||||
func TestCreateSession(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
env map[string]string
|
||||
config *Config
|
||||
wantCreds credentials.Value
|
||||
wantDefaultChain bool
|
||||
wantRegion string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
desc: "config is nil",
|
||||
wantErr: "config is nil",
|
||||
},
|
||||
{
|
||||
desc: "session token without access key id or secret access key",
|
||||
config: &Config{SessionToken: "foo"},
|
||||
wantErr: "SessionToken must be supplied with AccessKeyID and SecretAccessKey",
|
||||
},
|
||||
{
|
||||
desc: "access key id without secret access key",
|
||||
config: &Config{AccessKeyID: "foo"},
|
||||
wantErr: "AccessKeyID and SecretAccessKey must be supplied together",
|
||||
},
|
||||
{
|
||||
desc: "access key id without secret access key",
|
||||
config: &Config{SecretAccessKey: "foo"},
|
||||
wantErr: "AccessKeyID and SecretAccessKey must be supplied together",
|
||||
},
|
||||
{
|
||||
desc: "credentials from default chain",
|
||||
config: &Config{},
|
||||
wantDefaultChain: true,
|
||||
},
|
||||
{
|
||||
desc: "static credentials",
|
||||
config: &Config{
|
||||
AccessKeyID: "one",
|
||||
SecretAccessKey: "two",
|
||||
},
|
||||
wantCreds: credentials.Value{
|
||||
AccessKeyID: "one",
|
||||
SecretAccessKey: "two",
|
||||
SessionToken: "",
|
||||
ProviderName: credentials.StaticProviderName,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "static credentials with session token",
|
||||
config: &Config{
|
||||
AccessKeyID: "one",
|
||||
SecretAccessKey: "two",
|
||||
SessionToken: "three",
|
||||
},
|
||||
wantCreds: credentials.Value{
|
||||
AccessKeyID: "one",
|
||||
SecretAccessKey: "two",
|
||||
SessionToken: "three",
|
||||
ProviderName: credentials.StaticProviderName,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "region from env",
|
||||
config: &Config{},
|
||||
env: map[string]string{
|
||||
"AWS_REGION": "foo",
|
||||
},
|
||||
wantDefaultChain: true,
|
||||
wantRegion: "foo",
|
||||
},
|
||||
{
|
||||
desc: "static region",
|
||||
config: &Config{
|
||||
Region: "one",
|
||||
},
|
||||
env: map[string]string{
|
||||
"AWS_REGION": "foo",
|
||||
},
|
||||
wantDefaultChain: true,
|
||||
wantRegion: "one",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
defer envTest.RestoreEnv()
|
||||
envTest.ClearEnv()
|
||||
|
||||
envTest.Apply(test.env)
|
||||
|
||||
sess, err := createSession(test.config)
|
||||
requireErr(t, err, test.wantErr)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
gotCreds, err := sess.Config.Credentials.Get()
|
||||
|
||||
if test.wantDefaultChain {
|
||||
assert.NotEqual(t, credentials.StaticProviderName, gotCreds.ProviderName)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.wantCreds, gotCreds)
|
||||
}
|
||||
|
||||
if test.wantRegion != "" {
|
||||
assert.Equal(t, test.wantRegion, aws.StringValue(sess.Config.Region))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func requireErr(t *testing.T, err error, wantErr string) {
|
||||
t.Helper()
|
||||
|
||||
switch {
|
||||
case err != nil && wantErr == "":
|
||||
// force the assertion error.
|
||||
require.NoError(t, err)
|
||||
|
||||
case err == nil && wantErr != "":
|
||||
// force the assertion error.
|
||||
require.EqualError(t, err, wantErr)
|
||||
|
||||
case err != nil && wantErr != "":
|
||||
require.EqualError(t, err, wantErr)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue