From 20d50a559f077fe8d392957aa1e362041ab64f3c Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Tue, 9 Oct 2018 19:03:07 +0200 Subject: [PATCH] route53: fix challenge. (#665) --- .golangci.toml | 1 - acme/utils.go | 4 +- .../{testutil_test.go => mock_test.go} | 7 +- providers/dns/route53/route53.go | 184 +++++++++++------- .../dns/route53/route53_integration_test.go | 4 +- providers/dns/route53/route53_test.go | 131 +++++++++---- 6 files changed, 225 insertions(+), 106 deletions(-) rename providers/dns/route53/{testutil_test.go => mock_test.go} (84%) diff --git a/.golangci.toml b/.golangci.toml index 9bfee405..681218eb 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -33,7 +33,6 @@ max-per-linter = 0 max-same = 0 exclude = [ - "session.New is deprecated:", # providers/dns/route53/route53_integration_test.go | providers/dns/route53/route53_test.go "func (.+)disableAuthz(.) is unused", # acme/client.go#disableAuthz "type (.+)deactivateAuthMessage(.) is unused", # acme/messages.go#deactivateAuthMessage "(.)limitReader(.) - (.)numBytes(.) always receives (.)1048576(.)", # acme/crypto.go#limitReader diff --git a/acme/utils.go b/acme/utils.go index 2c68c4a3..f3160806 100644 --- a/acme/utils.go +++ b/acme/utils.go @@ -12,10 +12,10 @@ func WaitFor(timeout, interval time.Duration, f func() (bool, error)) error { log.Infof("Wait [timeout: %s, interval: %s]", timeout, interval) var lastErr string - timeup := time.After(timeout) + timeUp := time.After(timeout) for { select { - case <-timeup: + case <-timeUp: return fmt.Errorf("time limit exceeded: last error: %s", lastErr) default: } diff --git a/providers/dns/route53/testutil_test.go b/providers/dns/route53/mock_test.go similarity index 84% rename from providers/dns/route53/testutil_test.go rename to providers/dns/route53/mock_test.go index 22ad228c..79b0bbaf 100644 --- a/providers/dns/route53/testutil_test.go +++ b/providers/dns/route53/mock_test.go @@ -24,8 +24,11 @@ func newMockServer(t *testing.T, responses MockResponseMap) *httptest.Server { path := r.URL.Path resp, ok := responses[path] if !ok { - msg := fmt.Sprintf("Requested path not found in response map: %s", path) - require.FailNow(t, msg) + resp, ok = responses[r.RequestURI] + if !ok { + msg := fmt.Sprintf("Requested path not found in response map: %s", path) + require.FailNow(t, msg) + } } w.Header().Set("Content-Type", "application/xml") diff --git a/providers/dns/route53/route53.go b/providers/dns/route53/route53.go index 9416c6a4..1cace78e 100644 --- a/providers/dns/route53/route53.go +++ b/providers/dns/route53/route53.go @@ -44,17 +44,16 @@ type DNSProvider struct { config *Config } -// customRetryer implements the client.Retryer interface by composing the -// DefaultRetryer. It controls the logic for retrying recoverable request -// errors (e.g. when rate limits are exceeded). +// customRetryer implements the client.Retryer interface by composing the DefaultRetryer. +// It controls the logic for retrying recoverable request errors (e.g. when rate limits are exceeded). type customRetryer struct { client.DefaultRetryer } // RetryRules overwrites the DefaultRetryer's method. -// It uses a basic exponential backoff algorithm that returns an initial -// delay of ~400ms with an upper limit of ~30 seconds which should prevent -// causing a high number of consecutive throttling errors. +// It uses a basic exponential backoff algorithm: +// that returns an initial delay of ~400ms with an upper limit of ~30 seconds, +// which should prevent causing a high number of consecutive throttling errors. // For reference: Route 53 enforces an account-wide(!) 5req/s query limit. func (d customRetryer) RetryRules(r *request.Request) time.Duration { retryCount := r.RetryCount @@ -66,57 +65,81 @@ func (d customRetryer) RetryRules(r *request.Request) time.Duration { return time.Duration(delay) * time.Millisecond } -// NewDNSProvider returns a DNSProvider instance configured for the AWS -// Route 53 service. +// NewDNSProvider returns a DNSProvider instance configured for the AWS Route 53 service. // -// AWS Credentials are automatically detected in the following locations -// and prioritized in the following order: +// AWS Credentials are automatically detected in the following locations and prioritized in the following order: // 1. Environment variables: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, // AWS_REGION, [AWS_SESSION_TOKEN] // 2. Shared credentials file (defaults to ~/.aws/credentials) // 3. Amazon EC2 IAM role // -// If AWS_HOSTED_ZONE_ID is not set, Lego tries to determine the correct -// public hosted zone via the FQDN. +// If AWS_HOSTED_ZONE_ID is not set, Lego tries to determine the correct public hosted zone via the FQDN. // // See also: https://github.com/aws/aws-sdk-go/wiki/configuring-sdk func NewDNSProvider() (*DNSProvider, error) { return NewDNSProviderConfig(NewDefaultConfig()) } -// NewDNSProviderConfig takes a given config ans returns a custom configured -// DNSProvider instance +// NewDNSProviderConfig takes a given config ans returns a custom configured DNSProvider instance func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { if config == nil { return nil, errors.New("route53: the configuration of the Route53 DNS provider is nil") } - r := customRetryer{} - r.NumMaxRetries = config.MaxRetries - sessionCfg := request.WithRetryer(aws.NewConfig(), r) + retry := customRetryer{} + retry.NumMaxRetries = config.MaxRetries + sessionCfg := request.WithRetryer(aws.NewConfig(), retry) + sess, err := session.NewSessionWithOptions(session.Options{Config: *sessionCfg}) if err != nil { return nil, err } - cl := route53.New(sess) - return &DNSProvider{ - client: cl, - config: config, - }, nil + cl := route53.New(sess) + return &DNSProvider{client: cl, config: config}, nil } // Timeout returns the timeout and interval to use when checking for DNS // propagation. -func (r *DNSProvider) Timeout() (timeout, interval time.Duration) { - return r.config.PropagationTimeout, r.config.PollingInterval +func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { + return d.config.PropagationTimeout, d.config.PollingInterval } // Present creates a TXT record using the specified parameters -func (r *DNSProvider) Present(domain, token, keyAuth string) error { +func (d *DNSProvider) Present(domain, token, keyAuth string) error { fqdn, value, _ := acme.DNS01Record(domain, keyAuth) - err := r.changeRecord("UPSERT", fqdn, `"`+value+`"`, r.config.TTL) + hostedZoneID, err := d.getHostedZoneID(fqdn) + if err != nil { + return fmt.Errorf("route53: failed to determine hosted zone ID: %v", err) + } + + records, err := d.getExistingRecordSets(hostedZoneID, fqdn) + if err != nil { + return fmt.Errorf("route53: %v", err) + } + + realValue := `"` + value + `"` + + var found bool + for _, record := range records { + if aws.StringValue(record.Value) == realValue { + found = true + } + } + + if !found { + records = append(records, &route53.ResourceRecord{Value: aws.String(realValue)}) + } + + recordSet := &route53.ResourceRecordSet{ + Name: aws.String(fqdn), + Type: aws.String("TXT"), + TTL: aws.Int64(int64(d.config.TTL)), + ResourceRecords: records, + } + + err = d.changeRecord(route53.ChangeActionUpsert, hostedZoneID, recordSet) if err != nil { return fmt.Errorf("route53: %v", err) } @@ -124,61 +147,101 @@ func (r *DNSProvider) Present(domain, token, keyAuth string) error { } // CleanUp removes the TXT record matching the specified parameters -func (r *DNSProvider) CleanUp(domain, token, keyAuth string) error { - fqdn, value, _ := acme.DNS01Record(domain, keyAuth) +func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { + fqdn, _, _ := acme.DNS01Record(domain, keyAuth) - err := r.changeRecord("DELETE", fqdn, `"`+value+`"`, r.config.TTL) + hostedZoneID, err := d.getHostedZoneID(fqdn) + if err != nil { + return fmt.Errorf("failed to determine Route 53 hosted zone ID: %v", err) + } + + records, err := d.getExistingRecordSets(hostedZoneID, fqdn) + if err != nil { + return fmt.Errorf("route53: %v", err) + } + + if len(records) == 0 { + return nil + } + + recordSet := &route53.ResourceRecordSet{ + Name: aws.String(fqdn), + Type: aws.String("TXT"), + TTL: aws.Int64(int64(d.config.TTL)), + ResourceRecords: records, + } + + err = d.changeRecord(route53.ChangeActionDelete, hostedZoneID, recordSet) if err != nil { return fmt.Errorf("route53: %v", err) } return nil } -func (r *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error { - hostedZoneID, err := r.getHostedZoneID(fqdn) - if err != nil { - return fmt.Errorf("failed to determine Route 53 hosted zone ID: %v", err) - } - - recordSet := newTXTRecordSet(fqdn, value, ttl) - reqParams := &route53.ChangeResourceRecordSetsInput{ +func (d *DNSProvider) changeRecord(action, hostedZoneID string, recordSet *route53.ResourceRecordSet) error { + recordSetInput := &route53.ChangeResourceRecordSetsInput{ HostedZoneId: aws.String(hostedZoneID), ChangeBatch: &route53.ChangeBatch{ Comment: aws.String("Managed by Lego"), - Changes: []*route53.Change{ - { - Action: aws.String(action), - ResourceRecordSet: recordSet, - }, - }, + Changes: []*route53.Change{{ + Action: aws.String(action), + ResourceRecordSet: recordSet, + }}, }, } - resp, err := r.client.ChangeResourceRecordSets(reqParams) + resp, err := d.client.ChangeResourceRecordSets(recordSetInput) if err != nil { return fmt.Errorf("failed to change record set: %v", err) } - statusID := resp.ChangeInfo.Id + changeID := resp.ChangeInfo.Id - return acme.WaitFor(r.config.PropagationTimeout, r.config.PollingInterval, func() (bool, error) { - reqParams := &route53.GetChangeInput{ - Id: statusID, - } - resp, err := r.client.GetChange(reqParams) + return acme.WaitFor(d.config.PropagationTimeout, d.config.PollingInterval, func() (bool, error) { + reqParams := &route53.GetChangeInput{Id: changeID} + + resp, err := d.client.GetChange(reqParams) if err != nil { return false, fmt.Errorf("failed to query change status: %v", err) } + if aws.StringValue(resp.ChangeInfo.Status) == route53.ChangeStatusInsync { return true, nil } - return false, nil + return false, fmt.Errorf("unable to retrieve change: ID=%s", aws.StringValue(changeID)) }) } -func (r *DNSProvider) getHostedZoneID(fqdn string) (string, error) { - if r.config.HostedZoneID != "" { - return r.config.HostedZoneID, nil +func (d *DNSProvider) getExistingRecordSets(hostedZoneID string, fqdn string) ([]*route53.ResourceRecord, error) { + listInput := &route53.ListResourceRecordSetsInput{ + HostedZoneId: aws.String(hostedZoneID), + StartRecordName: aws.String(fqdn), + StartRecordType: aws.String("TXT"), + } + + recordSetsOutput, err := d.client.ListResourceRecordSets(listInput) + if err != nil { + return nil, err + } + + if recordSetsOutput == nil { + return nil, nil + } + + var records []*route53.ResourceRecord + + for _, recordSet := range recordSetsOutput.ResourceRecordSets { + if aws.StringValue(recordSet.Name) == fqdn { + records = append(records, recordSet.ResourceRecords...) + } + } + + return records, nil +} + +func (d *DNSProvider) getHostedZoneID(fqdn string) (string, error) { + if d.config.HostedZoneID != "" { + return d.config.HostedZoneID, nil } authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers) @@ -190,7 +253,7 @@ func (r *DNSProvider) getHostedZoneID(fqdn string) (string, error) { reqParams := &route53.ListHostedZonesByNameInput{ DNSName: aws.String(acme.UnFqdn(authZone)), } - resp, err := r.client.ListHostedZonesByName(reqParams) + resp, err := d.client.ListHostedZonesByName(reqParams) if err != nil { return "", err } @@ -214,14 +277,3 @@ func (r *DNSProvider) getHostedZoneID(fqdn string) (string, error) { return hostedZoneID, nil } - -func newTXTRecordSet(fqdn, value string, ttl int) *route53.ResourceRecordSet { - return &route53.ResourceRecordSet{ - Name: aws.String(fqdn), - Type: aws.String("TXT"), - TTL: aws.Int64(int64(ttl)), - ResourceRecords: []*route53.ResourceRecord{ - {Value: aws.String(value)}, - }, - } -} diff --git a/providers/dns/route53/route53_integration_test.go b/providers/dns/route53/route53_integration_test.go index 6d0d9e89..68c076fb 100644 --- a/providers/dns/route53/route53_integration_test.go +++ b/providers/dns/route53/route53_integration_test.go @@ -26,7 +26,9 @@ func TestRoute53TTL(t *testing.T) { // we need a separate R53 client here as the one in the DNS provider is unexported. fqdn := "_acme-challenge." + r53Domain + "." - svc := route53.New(session.New()) + sess, err := session.NewSession() + require.NoError(t, err) + svc := route53.New(sess) defer func() { errC := provider.CleanUp(r53Domain, "foo", "bar") diff --git a/providers/dns/route53/route53_test.go b/providers/dns/route53/route53_test.go index 9f278b04..4f2a0410 100644 --- a/providers/dns/route53/route53_test.go +++ b/providers/dns/route53/route53_test.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/route53" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -47,7 +48,18 @@ func restoreEnv() { os.Setenv("AWS_TTL", r53AwsTTL) os.Setenv("AWS_PROPAGATION_TIMEOUT", r53AwsPropagationTimeout) os.Setenv("AWS_POLLING_INTERVAL", r53AwsPollingInterval) +} +func cleanEnv() { + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + os.Unsetenv("AWS_REGION") + os.Unsetenv("AWS_HOSTED_ZONE_ID") + + os.Unsetenv("AWS_MAX_RETRIES") + os.Unsetenv("AWS_TTL") + os.Unsetenv("AWS_PROPAGATION_TIMEOUT") + os.Unsetenv("AWS_POLLING_INTERVAL") } func makeRoute53Provider(ts *httptest.Server) *DNSProvider { @@ -58,75 +70,126 @@ func makeRoute53Provider(ts *httptest.Server) *DNSProvider { MaxRetries: aws.Int(1), } - client := route53.New(session.New(config)) + sess, err := session.NewSession(config) + if err != nil { + panic(err) + } + client := route53.New(sess) cfg := NewDefaultConfig() return &DNSProvider{client: client, config: cfg} } -func TestCredentialsFromEnv(t *testing.T) { +func Test_loadCredentials_FromEnv(t *testing.T) { defer restoreEnv() os.Setenv("AWS_ACCESS_KEY_ID", "123") - os.Setenv("AWS_SECRET_ACCESS_KEY", "123") + os.Setenv("AWS_SECRET_ACCESS_KEY", "456") os.Setenv("AWS_REGION", "us-east-1") config := &aws.Config{ CredentialsChainVerboseErrors: aws.Bool(true), } - sess := session.New(config) - _, err := sess.Config.Credentials.Get() + sess, err := session.NewSession(config) + require.NoError(t, err) + + value, err := sess.Config.Credentials.Get() assert.NoError(t, err, "Expected credentials to be set from environment") + + expected := credentials.Value{ + AccessKeyID: "123", + SecretAccessKey: "456", + SessionToken: "", + ProviderName: "EnvConfigCredentials", + } + assert.Equal(t, expected, value) } -func TestRegionFromEnv(t *testing.T) { +func Test_loadRegion_FromEnv(t *testing.T) { defer restoreEnv() - os.Setenv("AWS_REGION", "us-east-1") + os.Setenv("AWS_REGION", route53.CloudWatchRegionUsEast1) - sess := session.New(aws.NewConfig()) - assert.Equal(t, "us-east-1", aws.StringValue(sess.Config.Region), "Expected Region to be set from environment") + sess, err := session.NewSession(aws.NewConfig()) + require.NoError(t, err) + + region := aws.StringValue(sess.Config.Region) + assert.Equal(t, route53.CloudWatchRegionUsEast1, region, "Region") } -func TestHostedZoneIDFromEnv(t *testing.T) { +func Test_getHostedZoneID_FromEnv(t *testing.T) { defer restoreEnv() - const testZoneID = "testzoneid" - os.Setenv("AWS_HOSTED_ZONE_ID", testZoneID) + expectedZoneID := "zoneID" + + os.Setenv("AWS_HOSTED_ZONE_ID", expectedZoneID) provider, err := NewDNSProvider() - assert.NoError(t, err, "Expected no error constructing DNSProvider") + assert.NoError(t, err) - fqdn, err := provider.getHostedZoneID("whatever") - assert.NoError(t, err, "Expected FQDN to be resolved to environment variable value") + hostedZoneID, err := provider.getHostedZoneID("whatever") + assert.NoError(t, err, "HostedZoneID") - assert.Equal(t, testZoneID, fqdn) + assert.Equal(t, expectedZoneID, hostedZoneID) } -func TestConfigFromEnv(t *testing.T) { +func TestNewDefaultConfig(t *testing.T) { defer restoreEnv() - config := NewDefaultConfig() - assert.Equal(t, config.TTL, 10, "Expected TTL to be use the default") + testCases := []struct { + desc string + envVars map[string]string + expected *Config + }{ + { + desc: "default configuration", + expected: &Config{ + MaxRetries: 5, + TTL: 10, + PropagationTimeout: 2 * time.Minute, + PollingInterval: 4 * time.Second, + }, + }, + { + desc: "", + envVars: map[string]string{ + "AWS_MAX_RETRIES": "10", + "AWS_TTL": "99", + "AWS_PROPAGATION_TIMEOUT": "60", + "AWS_POLLING_INTERVAL": "60", + "AWS_HOSTED_ZONE_ID": "abc123", + }, + expected: &Config{ + MaxRetries: 10, + TTL: 99, + PropagationTimeout: 60 * time.Second, + PollingInterval: 60 * time.Second, + HostedZoneID: "abc123", + }, + }, + } - os.Setenv("AWS_MAX_RETRIES", "10") - os.Setenv("AWS_TTL", "99") - os.Setenv("AWS_PROPAGATION_TIMEOUT", "60") - os.Setenv("AWS_POLLING_INTERVAL", "60") - const zoneID = "abc123" - os.Setenv("AWS_HOSTED_ZONE_ID", zoneID) + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + cleanEnv() + for key, value := range test.envVars { + os.Setenv(key, value) + } - config = NewDefaultConfig() - assert.Equal(t, config.MaxRetries, 10, "Expected PropagationTimeout to be configured from the environment") - assert.Equal(t, config.TTL, 99, "Expected TTL to be configured from the environment") - assert.Equal(t, config.PropagationTimeout, time.Second*60, "Expected PropagationTimeout to be configured from the environment") - assert.Equal(t, config.PollingInterval, time.Second*60, "Expected PollingInterval to be configured from the environment") - assert.Equal(t, config.HostedZoneID, zoneID, "Expected HostedZoneID to be configured from the environment") + config := NewDefaultConfig() + + assert.Equal(t, test.expected, config) + }) + } } func TestRoute53Present(t *testing.T) { mockResponses := MockResponseMap{ - "/2013-04-01/hostedzonesbyname": MockResponse{StatusCode: 200, Body: ListHostedZonesByNameResponse}, - "/2013-04-01/hostedzone/ABCDEFG/rrset/": MockResponse{StatusCode: 200, Body: ChangeResourceRecordSetsResponse}, - "/2013-04-01/change/123456": MockResponse{StatusCode: 200, Body: GetChangeResponse}, + "/2013-04-01/hostedzonesbyname": {StatusCode: 200, Body: ListHostedZonesByNameResponse}, + "/2013-04-01/hostedzone/ABCDEFG/rrset/": {StatusCode: 200, Body: ChangeResourceRecordSetsResponse}, + "/2013-04-01/change/123456": {StatusCode: 200, Body: GetChangeResponse}, + "/2013-04-01/hostedzone/ABCDEFG/rrset?name=_acme-challenge.example.com.&type=TXT": { + StatusCode: 200, + Body: "", + }, } ts := newMockServer(t, mockResponses)